---
title: Trainer
keywords: fastai
sidebar: home_sidebar
summary: "Implementation of PyTorch model trainer."
description: "Implementation of PyTorch model trainer."
nb_path: "nbs/trainers/trainers.trainer.ipynb"
---
{% raw %}
{% endraw %}

v1

{% raw %}
from recohut.datasets.movielens import ML1mRatingDataset

# models
from recohut.models.afm import AFM
from recohut.models.afn import AFN
from recohut.models.autoint import AutoInt
from recohut.models.dcn import DCN
from recohut.models.deepfm import DeepFM
from recohut.models.ffm import FFM
from recohut.models.fm import FM
from recohut.models.fnfm import FNFM
from recohut.models.fnn import FNN
from recohut.models.hofm import HOFM
from recohut.models.lr import LR
from recohut.models.ncf import NCF
from recohut.models.nfm import NFM
from recohut.models.pnn import PNN
from recohut.models.wide_and_deep import WideAndDeep
from recohut.models.xdeepfm import xDeepFM
{% endraw %} {% raw %}
ds = ML1mRatingDataset(root='/content/ML1m', min_uc=10, min_sc=5)
Downloading http://files.grouplens.org/datasets/movielens/ml-1m.zip
Extracting /content/ML1m/raw/ml-1m.zip
Processing...
Done!
{% endraw %} {% raw %}
import torch
import os
import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
{% endraw %} {% raw %}
class Args:
    def __init__(self,
                 dataset='ml_1m',
                 model='wide_and_deep'
                 ):
        self.dataset = dataset
        self.model = model
        # dataset
        if dataset == 'ml_1m':
            self.dataset_root = '/content/ML1m'
            self.min_uc = 20
            self.min_sc = 20

        # model training
        self.device = 'cpu' # 'cuda:0'
        self.num_workers = 2
        self.batch_size = 256
        self.lr = 0.001
        self.weight_decay = 1e-6
        self.save_dir = '/content/chkpt'
        self.n_epochs = 2
        self.dropout = 0.2
        self.log_interval = 100

        # model architecture
        if model == 'wide_and_deep':
            self.embed_dim = 16
            self.mlp_dims = (16, 16)
        elif model == 'fm':
            self.embed_dim = 16
        elif model == 'ffm':
            self.embed_dim = 4
        elif model == 'hofm':
            self.embed_dim = 16
            self.order = 3
        elif model == 'fnn':
            self.embed_dim = 16
            self.mlp_dims = (16, 16)
        elif model == 'ipnn':
            self.embed_dim = 16
            self.mlp_dims = (16,)
            self.method = 'inner'
        elif model == 'opnn':
            self.embed_dim = 16
            self.mlp_dims = (16,)
            self.method = 'outer'
        elif model == 'dcn':
            self.embed_dim = 16
            self.num_layers = 3
            self.mlp_dims = (16, 16)
        elif model == 'nfm':
            self.embed_dim = 64
            self.mlp_dims = (64,)
            self.dropouts = (0.2, 0.2)
        elif model == 'ncf':
            self.embed_dim = 16
            self.mlp_dims = (16, 16)
        elif model == 'fnfm':
            self.embed_dim = 4
            self.mlp_dims = (64,)
            self.dropouts = (0.2, 0.2)
        elif model == 'deep_fm':
            self.embed_dim = 16
            self.mlp_dims = (16, 16)
        elif model == 'xdeep_fm':
            self.embed_dim = 16
            self.cross_layer_sizes = (16, 16)
            self.split_half = False
            self.mlp_dims = (16, 16)
        elif model == 'afm':
            self.embed_dim = 16
            self.attn_size = 16
            self.dropouts = (0.2, 0.2)
        elif model == 'autoint':
            self.embed_dim = 16
            self.atten_embed_dim = 64
            self.num_heads = 2
            self.num_layers = 3
            self.mlp_dims = (400, 400)
            self.dropouts = (0, 0, 0)
        elif model == 'afn':
            self.embed_dim = 16
            self.LNN_dim = 1500
            self.mlp_dims = (400, 400, 400)
            self.dropouts = (0, 0, 0)

    def get_dataset(self):
        if self.dataset == 'ml_1m':
            return ML1mRatingDataset(root=self.dataset_root,
                                     min_uc=self.min_uc,
                                     min_sc=self.min_sc)

    def get_model(self, field_dims, user_field_idx=None, item_field_idx=None):
        if self.model == 'wide_and_deep':
            return WideAndDeep(field_dims, embed_dim=self.embed_dim,
                               mlp_dims=self.mlp_dims, dropout=self.dropout)
        elif self.model == 'fm':
            return FM(field_dims, embed_dim=self.embed_dim)
        elif self.model == 'lr':
            return LR(field_dims)
        elif self.model == 'ffm':
            return FFM(field_dims, embed_dim=self.embed_dim)
        elif self.model == 'hofm':
            return HOFM(field_dims, embed_dim=self.embed_dim, order=self.order)
        elif self.model == 'fnn':
            return FNN(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, dropout=self.dropout)
        elif self.model in ('ipnn', 'opnn'):
            return PNN(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, method=self.method,
                       dropout=self.dropout)
        elif self.model == 'dcn':
            return DCN(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, num_layers=self.num_layers,
                       dropout=self.dropout)
        elif self.model == 'nfm':
            return NFM(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, dropouts=self.dropouts)
        elif self.model == 'ncf':
            return NCF(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, dropout=self.dropout,
                       user_field_idx=user_field_idx,
                       item_field_idx=item_field_idx)
        elif self.model == 'fnfm':
            return FNFM(field_dims, embed_dim=self.embed_dim,
                        mlp_dims=self.mlp_dims, dropouts=self.dropouts)
        elif self.model == 'deep_fm':
            return DeepFM(field_dims, embed_dim=self.embed_dim,
                          mlp_dims=self.mlp_dims, dropout=self.dropout)
        elif self.model == 'xdeep_fm':
            return xDeepFM(field_dims, embed_dim=self.embed_dim,
                           mlp_dims=self.mlp_dims, dropout=self.dropout,
                           cross_layer_sizes=self.cross_layer_sizes,
                           split_half=self.split_half)
        elif self.model == 'afm':
            return AFM(field_dims, embed_dim=self.embed_dim,
                       dropouts=self.dropouts, attn_size=self.attn_size)
        elif self.model == 'autoint':
            return AutoInt(field_dims, embed_dim=self.embed_dim,
                           mlp_dims=self.mlp_dims, dropouts=self.dropouts,
                           atten_embed_dim=self.atten_embed_dim,
                           num_heads=self.num_heads,
                           num_layers=self.num_layers)
        elif self.model == 'afn':
            return AFN(field_dims, embed_dim=self.embed_dim,
                       mlp_dims=self.mlp_dims, dropouts=self.dropouts,
                       LNN_dim=self.LNN_dim)
{% endraw %}
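A quick smoke test of the `Args` factory methods (a minimal sketch; it assumes the ML-1m download above has already completed, and `get_dataset` will reprocess the ratings with the configured `min_uc`/`min_sc` filters):

{% raw %}
args = Args(model='fm')
ds = args.get_dataset()                 # ML1mRatingDataset with min_uc=min_sc=20
model = args.get_model(ds.field_dims)   # FM with embed_dim=16
print(type(model).__name__)
{% endraw %}

{% raw %}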
class EarlyStopper(object):

    def __init__(self, num_trials, save_path):
        self.num_trials = num_trials
        self.trial_counter = 0
        self.best_accuracy = 0
        self.save_path = save_path

    def is_continuable(self, model, accuracy):
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.trial_counter = 0
            torch.save(model, self.save_path)
            return True
        elif self.trial_counter + 1 < self.num_trials:
            self.trial_counter += 1
            return True
        else:
            return False
{% endraw %}
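`EarlyStopper` keeps training alive for up to `num_trials` consecutive non-improving epochs and checkpoints the model whenever the metric improves. A minimal sketch of that behavior (the save path and the stand-in linear model are hypothetical):

{% raw %}
stopper = EarlyStopper(num_trials=2, save_path='/content/demo.pt')  # hypothetical path
tiny = torch.nn.Linear(4, 1)               # stand-in model, just to have something to save
print(stopper.is_continuable(tiny, 0.70))  # True  -> new best, checkpoint written
print(stopper.is_continuable(tiny, 0.65))  # True  -> first non-improving trial
print(stopper.is_continuable(tiny, 0.60))  # False -> num_trials exhausted, stop training
{% endraw %}

{% raw %}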
class Trainer:
    def __init__(self, args):
        device = torch.device(args.device)
        # dataset
        dataset = args.get_dataset()
        # model
        model = args.get_model(dataset.field_dims,
                               user_field_idx = dataset.user_field_idx,
                               item_field_idx = dataset.item_field_idx)
        model = model.to(device)
        model_name = type(model).__name__
        # data split
        train_length = int(len(dataset) * 0.8)
        valid_length = int(len(dataset) * 0.1)
        test_length = len(dataset) - train_length - valid_length
        # data loader
        train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
            dataset, (train_length, valid_length, test_length))
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        valid_data_loader = DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        # handlers
        criterion = torch.nn.BCELoss()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        os.makedirs(args.save_dir, exist_ok=True)
        early_stopper = EarlyStopper(num_trials=2, save_path=f'{args.save_dir}/{model_name}.pt')
        # # scheduler
        # # ref - https://github.com/sparsh-ai/stanza/blob/7961a0a00dc06b9b28b71954b38181d6a87aa803/trainer/bert.py#L36
        # import torch.optim as optim
        # if args.enable_lr_schedule:
        #     if args.enable_lr_warmup:
        #         self.lr_scheduler = self.get_linear_schedule_with_warmup(
        #             optimizer, args.warmup_steps, len(train_data_loader) * self.n_epochs)
        #     else:
        #         self.lr_scheduler = optim.lr_scheduler.StepLR(
        #             optimizer, step_size=args.decay_step, gamma=args.gamma)
        # training
        for epoch_i in range(args.n_epochs):
            self._train(model, optimizer, train_data_loader, criterion, device)
            auc = self._test(model, valid_data_loader, device)
            print('epoch:', epoch_i, 'validation: auc:', auc)
            if not early_stopper.is_continuable(model, auc):
                print(f'validation: best auc: {early_stopper.best_accuracy}')
                break
        # note: this evaluates the final model state; the best checkpoint is saved at save_path
        auc = self._test(model, test_data_loader, device)
        print(f'test auc: {auc}')

    @staticmethod
    def _train(model, optimizer, data_loader, criterion, device, log_interval=100):
        model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
        for i, (fields, target) in enumerate(tk0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            loss = criterion(y, target.float())
            model.zero_grad()
            loss.backward()
            # self.clip_gradients(5)
            optimizer.step()
            # if self.args.enable_lr_schedule:
            #     self.lr_scheduler.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0
    
    @staticmethod
    def _test(model, data_loader, device):
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
                fields, target = fields.to(device), target.to(device)
                y = model(fields)
                targets.extend(target.tolist())
                predicts.extend(y.tolist())
        return roc_auc_score(targets, predicts)

    # def clip_gradients(self, limit=5):
    #     """
    #     Reference:
    #         1. https://github.com/sparsh-ai/stanza/blob/7961a0a00dc06b9b28b71954b38181d6a87aa803/trainer/bert.py#L175
    #     """
    #     for p in self.model.parameters():
    #         nn.utils.clip_grad_norm_(p, 5)

    # def _create_optimizer(self):
    #     args = self.args
    #     param_optimizer = list(self.model.named_parameters())
    #     no_decay = ['bias', 'layer_norm']
    #     optimizer_grouped_parameters = [
    #         {
    #             'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
    #             'weight_decay': args.weight_decay,
    #         },
    #         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    #     ]
    #     if args.optimizer.lower() == 'adamw':
    #         return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
    #     elif args.optimizer.lower() == 'adam':
    #         return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
    #     elif args.optimizer.lower() == 'sgd':
    #         return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    #     else:
    #         raise ValueError

    # def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    #     # based on hugging face get_linear_schedule_with_warmup
    #     def lr_lambda(current_step: int):
    #         if current_step < num_warmup_steps:
    #             return float(current_step) / float(max(1, num_warmup_steps))
    #         return max(
    #             0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
    #         )

    #     return LambdaLR(optimizer, lr_lambda, last_epoch)
{% endraw %} {% raw %}
models = [
          'wide_and_deep',
          'fm',
          'lr',
          'ffm',
          'hofm',
          'fnn',
          'ipnn',
          'opnn',
          'dcn',
          'nfm',
          'ncf',
          'fnfm',
          'deep_fm',
          'xdeep_fm',
          'afm',
        #   'autoint',
        #   'afn'
          ]

for model in models:
    args = Args(model=model)
    trainer = Trainer(args)
Processing...
Done!
100%|██████████| 3126/3126 [00:23<00:00, 135.91it/s, loss=0.57]
100%|██████████| 391/391 [00:01<00:00, 252.62it/s]
epoch: 0 validation: auc: 0.7781601005135064
100%|██████████| 3126/3126 [00:22<00:00, 137.03it/s, loss=0.557]
100%|██████████| 391/391 [00:01<00:00, 259.00it/s]
epoch: 1 validation: auc: 0.7842454181872189
100%|██████████| 391/391 [00:01<00:00, 261.87it/s]
Processing...
test auc: 0.783773499847308
Done!
100%|██████████| 3126/3126 [00:15<00:00, 200.85it/s, loss=0.587]
100%|██████████| 391/391 [00:01<00:00, 277.38it/s]
epoch: 0 validation: auc: 0.7511323978391329
100%|██████████| 3126/3126 [00:15<00:00, 203.24it/s, loss=0.542]
100%|██████████| 391/391 [00:01<00:00, 286.19it/s]
epoch: 1 validation: auc: 0.7852232398637453
100%|██████████| 391/391 [00:01<00:00, 286.62it/s]
Processing...
test auc: 0.7851983970544512
Done!
100%|██████████| 3126/3126 [00:12<00:00, 243.01it/s, loss=0.713]
100%|██████████| 391/391 [00:01<00:00, 290.07it/s]
epoch: 0 validation: auc: 0.606845663039941
100%|██████████| 3126/3126 [00:12<00:00, 243.68it/s, loss=0.625]
100%|██████████| 391/391 [00:01<00:00, 290.60it/s]
epoch: 1 validation: auc: 0.6962495583229628
100%|██████████| 391/391 [00:01<00:00, 280.21it/s]
Processing...
test auc: 0.6917994954031111
Done!
100%|██████████| 3126/3126 [00:16<00:00, 189.89it/s, loss=0.639]
100%|██████████| 391/391 [00:01<00:00, 275.44it/s]
epoch: 0 validation: auc: 0.6956660360854087
100%|██████████| 3126/3126 [00:16<00:00, 190.43it/s, loss=0.559]
100%|██████████| 391/391 [00:01<00:00, 279.32it/s]
epoch: 1 validation: auc: 0.769259926201433
100%|██████████| 391/391 [00:01<00:00, 275.84it/s]
Processing...
test auc: 0.7694256825177728
Done!
100%|██████████| 3126/3126 [00:23<00:00, 135.66it/s, loss=0.585]
100%|██████████| 391/391 [00:01<00:00, 229.47it/s]
epoch: 0 validation: auc: 0.7508361070441243
100%|██████████| 3126/3126 [00:23<00:00, 135.52it/s, loss=0.538]
100%|██████████| 391/391 [00:01<00:00, 229.00it/s]
epoch: 1 validation: auc: 0.7867336507526798
100%|██████████| 391/391 [00:01<00:00, 226.59it/s]
Processing...
test auc: 0.7849653473859624
Done!
100%|██████████| 3126/3126 [00:21<00:00, 146.97it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 272.83it/s]
epoch: 0 validation: auc: 0.7899586086314532
100%|██████████| 3126/3126 [00:21<00:00, 148.00it/s, loss=0.544]
100%|██████████| 391/391 [00:01<00:00, 268.07it/s]
epoch: 1 validation: auc: 0.7938707366151592
100%|██████████| 391/391 [00:01<00:00, 267.69it/s]
Processing...
test auc: 0.7935777015287597
Done!
100%|██████████| 3126/3126 [00:20<00:00, 151.01it/s, loss=0.55]
100%|██████████| 391/391 [00:01<00:00, 253.01it/s]
epoch: 0 validation: auc: 0.7901787607777198
100%|██████████| 3126/3126 [00:20<00:00, 151.50it/s, loss=0.536]
100%|██████████| 391/391 [00:01<00:00, 258.53it/s]
epoch: 1 validation: auc: 0.7958062417181883
100%|██████████| 391/391 [00:01<00:00, 265.78it/s]
Processing...
test auc: 0.7959379435427811
Done!
100%|██████████| 3126/3126 [00:21<00:00, 144.98it/s, loss=0.548]
100%|██████████| 391/391 [00:01<00:00, 256.43it/s]
epoch: 0 validation: auc: 0.7943316704845618
100%|██████████| 3126/3126 [00:21<00:00, 145.26it/s, loss=0.53]
100%|██████████| 391/391 [00:01<00:00, 252.76it/s]
epoch: 1 validation: auc: 0.8027591784990165
100%|██████████| 391/391 [00:01<00:00, 259.79it/s]
Processing...
test auc: 0.8016146552653354
Done!
100%|██████████| 3126/3126 [00:26<00:00, 116.37it/s, loss=0.537]
100%|██████████| 391/391 [00:01<00:00, 240.12it/s]
epoch: 0 validation: auc: 0.7898151214837668
100%|██████████| 3126/3126 [00:26<00:00, 116.92it/s, loss=0.527]
100%|██████████| 391/391 [00:01<00:00, 239.57it/s]
epoch: 1 validation: auc: 0.7955138244674892
100%|██████████| 391/391 [00:01<00:00, 240.84it/s]
Processing...
test auc: 0.7964998271099959
Done!
100%|██████████| 3126/3126 [00:22<00:00, 138.66it/s, loss=0.586]
100%|██████████| 391/391 [00:01<00:00, 252.66it/s]
epoch: 0 validation: auc: 0.7631548463451637
100%|██████████| 3126/3126 [00:22<00:00, 137.08it/s, loss=0.551]
100%|██████████| 391/391 [00:01<00:00, 251.84it/s]
epoch: 1 validation: auc: 0.7752154803420491
100%|██████████| 391/391 [00:01<00:00, 252.42it/s]
Processing...
test auc: 0.7727792981788815
Done!
100%|██████████| 3126/3126 [00:23<00:00, 132.24it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 248.61it/s]
epoch: 0 validation: auc: 0.7876433331502086
100%|██████████| 3126/3126 [00:23<00:00, 132.00it/s, loss=0.543]
100%|██████████| 391/391 [00:01<00:00, 249.84it/s]
epoch: 1 validation: auc: 0.7923030405914255
100%|██████████| 391/391 [00:01<00:00, 257.83it/s]
Processing...
test auc: 0.7930787548185895
Done!
100%|██████████| 3126/3126 [00:23<00:00, 133.99it/s, loss=0.61]
100%|██████████| 391/391 [00:01<00:00, 250.20it/s]
epoch: 0 validation: auc: 0.7376150945998978
100%|██████████| 3126/3126 [00:23<00:00, 135.18it/s, loss=0.583]
100%|██████████| 391/391 [00:01<00:00, 246.25it/s]
epoch: 1 validation: auc: 0.7583206065924306
100%|██████████| 391/391 [00:01<00:00, 245.77it/s]
Processing...
test auc: 0.7594084947700983
Done!
100%|██████████| 3126/3126 [00:24<00:00, 127.45it/s, loss=0.569]
100%|██████████| 391/391 [00:01<00:00, 244.49it/s]
epoch: 0 validation: auc: 0.7806048647711028
100%|██████████| 3126/3126 [00:24<00:00, 128.70it/s, loss=0.554]
100%|██████████| 391/391 [00:01<00:00, 246.00it/s]
epoch: 1 validation: auc: 0.7857091265544482
100%|██████████| 391/391 [00:01<00:00, 245.62it/s]
Processing...
test auc: 0.7857843263334994
Done!
100%|██████████| 3126/3126 [00:30<00:00, 103.68it/s, loss=0.558]
100%|██████████| 391/391 [00:01<00:00, 219.06it/s]
epoch: 0 validation: auc: 0.7814674364890849
100%|██████████| 3126/3126 [00:29<00:00, 104.41it/s, loss=0.539]
100%|██████████| 391/391 [00:01<00:00, 214.94it/s]
epoch: 1 validation: auc: 0.7899837530572655
100%|██████████| 391/391 [00:01<00:00, 216.12it/s]
Processing...
test auc: 0.7863345464272122
Done!
100%|██████████| 3126/3126 [00:23<00:00, 133.32it/s, loss=0.606]
100%|██████████| 391/391 [00:01<00:00, 244.59it/s]
epoch: 0 validation: auc: 0.7590887701790624
100%|██████████| 3126/3126 [00:23<00:00, 134.06it/s, loss=0.576]
100%|██████████| 391/391 [00:01<00:00, 247.06it/s]
epoch: 1 validation: auc: 0.7820711568875622
100%|██████████| 391/391 [00:01<00:00, 247.66it/s]
test auc: 0.7835448236219698

{% endraw %} {% raw %}
models = [
          'autoint',
          'afn'
          ]

for model in models:
    args = Args(model=model)
    trainer = Trainer(args)
Processing...
Done!
100%|██████████| 3126/3126 [00:43<00:00, 72.44it/s, loss=0.551]
100%|██████████| 391/391 [00:02<00:00, 171.82it/s]
epoch: 0 validation: auc: 0.7838440329134869
100%|██████████| 3126/3126 [00:42<00:00, 73.24it/s, loss=0.532]
100%|██████████| 391/391 [00:02<00:00, 172.49it/s]
epoch: 1 validation: auc: 0.7924653055551055
100%|██████████| 391/391 [00:02<00:00, 169.07it/s]
Processing...
test auc: 0.7935854845577758
Done!
100%|██████████| 3126/3126 [01:24<00:00, 37.12it/s, loss=0.564]
100%|██████████| 391/391 [00:03<00:00, 107.76it/s]
epoch: 0 validation: auc: 0.7796980126749351
100%|██████████| 3126/3126 [01:23<00:00, 37.50it/s, loss=0.547]
100%|██████████| 391/391 [00:03<00:00, 108.16it/s]
epoch: 1 validation: auc: 0.7879478169612124
100%|██████████| 391/391 [00:03<00:00, 108.09it/s]
test auc: 0.7893059350190452

{% endraw %} {% raw %}
!tree --du -h -C /content/chkpt
/content/chkpt
├── [669K]  AFM.pt
├── [ 39M]  AFN.pt
├── [1.5M]  AutoInt.pt
├── [640K]  DCN.pt
├── [676K]  DeepFM.pt
├── [355K]  FFM.pt
├── [666K]  FM.pt
├── [363K]  FNFM.pt
├── [636K]  FNN.pt
├── [1.3M]  HOFM.pt
├── [ 41K]  LR.pt
├── [636K]  NCF.pt
├── [2.5M]  NFM.pt
├── [1.2M]  PNN.pt
├── [676K]  WideAndDeep.pt
└── [682K]  xDeepFM.pt

  51M used in 0 directories, 16 files
{% endraw %}
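
Because `EarlyStopper` saves the entire model object with `torch.save`, a checkpoint can be restored directly with `torch.load` (a minimal sketch; the path matches the `save_dir` above, and the two `(user, item)` index pairs are arbitrary):

{% raw %}
import torch
from recohut.models.wide_and_deep import WideAndDeep  # class must be importable to unpickle

model = torch.load('/content/chkpt/WideAndDeep.pt')  # full pickled module, not a state_dict
model.eval()
with torch.no_grad():
    fields = torch.tensor([[0, 10], [1, 42]])  # (user, item) field layout of the dataset
    print(model(fields))                       # predicted interaction probabilities
{% endraw %}

v2
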
{% raw %}
!pip install -q wandb
{% endraw %} {% raw %}
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F

import os
import copy
import random
from pathlib import Path
from collections import defaultdict

from argparse import Namespace
from joblib import dump, load
from tqdm import tqdm

import wandb
from torch.utils.data import DataLoader as dl
{% endraw %} {% raw %}
class RecsysDataset(torch.utils.data.Dataset):

    def __init__(self, df, usr_dict=None, mov_dict=None):
        self.df = df
        self.usr_dict = usr_dict
        self.mov_dict = mov_dict

    def __getitem__(self, index):
        row = self.df.iloc[index]
        if self.usr_dict and self.mov_dict:
            return [self.usr_dict[int(row['user_id'])],
                    self.mov_dict[int(row['movie_id'])]], row['rating']
        # without lookup dicts, assume 1-based ids and shift to 0-based indices
        return [int(row['user_id'] - 1), int(row['movie_id'] - 1)], row['rating']

    def __len__(self):
        return len(self.df)


sample = pd.DataFrame({'user_id': [1, 2, 3, 2, 2, 3, 2, 2],
                       'movie_id': [1, 2, 3, 3, 3, 2, 1, 1],
                       'rating': [2.0, 1.0, 4.0, 5.0, 1.3, 3.5, 3.0, 4.5]})
trn_ids = random.sample(range(8), 4)
valid_ids = [i for i in range(8) if i not in trn_ids]

sample_trn = copy.deepcopy(sample.iloc[trn_ids].reset_index())
sample_vld = copy.deepcopy(sample.iloc[valid_ids].reset_index())

sample_trn = RecsysDataset(sample_trn)
sample_vld = RecsysDataset(sample_vld)

train_loader = dl(sample_trn, batch_size=2, shuffle=True)
valid_loader = dl(sample_vld, batch_size=2, shuffle=True)
{% endraw %} {% raw %}
class NCF(nn.Module):

    def __init__(self, user_sz, item_sz, embd_sz, dropout_fac,
                 min_r=0.0, max_r=5.0, alpha=0.5, with_variable_alpha=False):
        super().__init__()
        self.dropout_fac = dropout_fac
        self.user_embd_mtrx = nn.Embedding(user_sz, embd_sz)
        self.item_embd_mtrx = nn.Embedding(item_sz, embd_sz)
        self.h = nn.Linear(embd_sz, 1)
        self.fst_lyr = nn.Linear(embd_sz * 2, embd_sz)
        self.snd_lyr = nn.Linear(embd_sz, embd_sz // 2)
        self.thrd_lyr = nn.Linear(embd_sz // 2, embd_sz // 4)
        self.out_lyr = nn.Linear(embd_sz // 4, 1)
        self.min_r, self.max_r = min_r, max_r
        # alpha blends the GMF and MLP branches; optionally learnable
        self.alpha = torch.tensor(alpha, requires_grad=with_variable_alpha)

    def forward(self, x):
        user_emd = self.user_embd_mtrx(x[0])
        item_emd = self.item_embd_mtrx(x[-1])
        # GMF branch: Hadamard product of the embeddings, projected to a scalar
        gmf = self.h(user_emd * item_emd)
        # MLP branch: concatenated embeddings through a small tower
        mlp = torch.cat([user_emd, item_emd], dim=-1)
        mlp = F.dropout(F.relu(self.fst_lyr(mlp)), p=self.dropout_fac)
        mlp = F.relu(self.snd_lyr(mlp))
        mlp = F.relu(self.thrd_lyr(mlp))
        mlp = self.out_lyr(mlp)
        # convex combination of the two branches, clipped to the rating range
        fac = torch.clip(self.alpha, min=0.0, max=1.0)
        out = fac * gmf + (1 - fac) * mlp
        return torch.clip(out, min=self.min_r, max=self.max_r)
{% endraw %} {% raw %}
model = NCF(3,3,4,0.5)
for u,r in train_loader:
    #user,item = u
    print(f'user:{u[0]},item:{u[-1]} and rating:{r}')
    #print(u)
    out = model(u)
    print(f'output of the network=> out:{out},shape:{out.shape}')
    break
user:tensor([2, 1]),item:tensor([2, 0]) and rating:tensor([4.0000, 4.5000], dtype=torch.float64)
output of the network=> out:tensor([[0.4322],
        [0.5724]], grad_fn=<ClampBackward1>),shape:torch.Size([2, 1])
{% endraw %} {% raw %}
class Trainer(object):
    def __init__(self, model, device, loss_fn=None, optimizer=None,
                 scheduler=None, artifacts_loc=None, exp_tracker=None):

        # Set params
        self.model = model
        self.device = device
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.store_loc = artifacts_loc
        self.exp_tracker = exp_tracker

    def train_step(self, dataloader):
        """Train step."""
        # Set model to train mode
        self.model.train()
        loss = 0.0

        # Iterate over train batches
        for i, batch in enumerate(dataloader):
            inputs, targets = batch
            inputs = [item.to(self.device) for item in inputs]
            targets = targets.to(self.device)
            self.optimizer.zero_grad()  # Reset gradients
            z = self.model(inputs)  # Forward pass
            targets = targets.reshape(z.shape)
            J = self.loss_fn(z.float(), targets.float())  # Compute loss
            J.backward()  # Backward pass
            self.optimizer.step()  # Update weights

            # Cumulative metrics (running mean of the batch losses)
            loss += (J.detach().item() - loss) / (i + 1)

        return loss

    def eval_step(self, dataloader):
        """Validation or test step."""
        # Set model to eval mode
        self.model.eval()
        loss = 0.0
        y_trues, y_probs = [], []

        # Iterate over val batches
        with torch.inference_mode():
            for i, batch in enumerate(dataloader):
                inputs, y_true = batch
                inputs = [item.to(self.device) for item in inputs]
                y_true = y_true.to(self.device).float()

                # Step
                z = self.model(inputs).float()  # Forward pass
                y_true = y_true.reshape(z.shape)
                J = self.loss_fn(z, y_true).item()

                # Cumulative Metrics
                loss += (J - loss) / (i + 1)

                # Store outputs
                y_prob = z.cpu().numpy()
                y_probs.extend(y_prob)
                y_trues.extend(y_true.cpu().numpy())

        return loss, np.vstack(y_trues), np.vstack(y_probs)

    def predict_step(self, dataloader):
        """Prediction step."""
        # Set model to eval mode
        self.model.eval()
        y_probs = []

        # Iterate over val batches
        with torch.inference_mode():
            for i, batch in enumerate(dataloader):

                # Forward pass w/ inputs (targets are ignored at prediction time)
                inputs, _ = batch
                inputs = [item.to(self.device) for item in inputs]
                z = self.model(inputs).float()

                # Store outputs
                y_prob = z.cpu().numpy()
                y_probs.extend(y_prob)

        return np.vstack(y_probs)
    
    def train(self, num_epochs, patience, train_dataloader, val_dataloader, 
              tolerance=1e-5):
        best_val_loss = np.inf
        training_stats = defaultdict(list)
        for epoch in tqdm(range(num_epochs)):
            # Steps
            train_loss = self.train_step(dataloader=train_dataloader)
            val_loss, _, _ = self.eval_step(dataloader=val_dataloader)
            # Store stats
            training_stats['epoch'].append(epoch)
            training_stats['train_loss'].append(train_loss)
            training_stats['val_loss'].append(val_loss)
            # Log stats to the experiment tracker, if one is configured
            # (call wandb.init(project=...) once before training)
            if self.exp_tracker == 'wandb':
                log_metrics = {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss}
                wandb.log(log_metrics, step=epoch)
            
            if self.scheduler:
                self.scheduler.step(val_loss)

            # Early stopping
            if val_loss < best_val_loss - tolerance:
                best_val_loss = val_loss
                best_model = copy.deepcopy(self.model)  # snapshot the weights, not a live reference
                _patience = patience  # reset patience
            else:
                _patience -= 1
            if not _patience:  # patience exhausted
                print("Stopping early!")
                break

            # Logging
            if epoch % 5 == 0:
                print(
                    f"Epoch: {epoch+1} | "
                    f"train_loss: {train_loss:.5f}, "
                    f"val_loss: {val_loss:.5f}, "
                    f"lr: {self.optimizer.param_groups[0]['lr']:.2E}, "
                    f"_patience: {_patience}"
                )
        if self.store_loc:
            pd.DataFrame(training_stats).to_csv(self.store_loc / 'training_stats.csv', index=False)
        return best_model, best_val_loss
{% endraw %} {% raw %}
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5)

trainer = Trainer(model, 'cpu', loss_fn, optimizer, scheduler)
{% endraw %} {% raw %}
trainer.train(100, 10, train_loader, valid_loader)
 11%|█         | 11/100 [00:00<00:01, 50.86it/s]
Epoch: 1 | train_loss: 8.55765, val_loss: 8.12049, lr: 1.00E-03, _patience: 10
Epoch: 6 | train_loss: 8.41908, val_loss: 8.07732, lr: 1.00E-03, _patience: 9
Epoch: 11 | train_loss: 8.28326, val_loss: 8.00117, lr: 1.00E-03, _patience: 8
Epoch: 16 | train_loss: 8.15473, val_loss: 7.88881, lr: 1.00E-03, _patience: 10
Epoch: 21 | train_loss: 8.01803, val_loss: 7.85543, lr: 1.00E-03, _patience: 8
 45%|████▌     | 45/100 [00:00<00:00, 93.60it/s]
Epoch: 26 | train_loss: 7.88548, val_loss: 7.76981, lr: 1.00E-03, _patience: 10
Epoch: 31 | train_loss: 7.75244, val_loss: 7.69328, lr: 1.00E-03, _patience: 9
Epoch: 36 | train_loss: 7.61687, val_loss: 7.61532, lr: 1.00E-03, _patience: 9
Epoch: 41 | train_loss: 7.48551, val_loss: 7.53750, lr: 1.00E-03, _patience: 9
Epoch: 46 | train_loss: 7.35090, val_loss: 7.42478, lr: 1.00E-03, _patience: 10
 65%|██████▌   | 65/100 [00:00<00:00, 96.05it/s]
Epoch: 51 | train_loss: 7.20997, val_loss: 7.34263, lr: 1.00E-03, _patience: 10
Epoch: 56 | train_loss: 7.07457, val_loss: 7.30977, lr: 1.00E-03, _patience: 9
Epoch: 61 | train_loss: 6.93769, val_loss: 7.23912, lr: 1.00E-03, _patience: 9
Epoch: 66 | train_loss: 6.80419, val_loss: 7.16288, lr: 1.00E-03, _patience: 9
 85%|████████▌ | 85/100 [00:00<00:00, 95.13it/s]
Epoch: 71 | train_loss: 6.66622, val_loss: 7.08776, lr: 1.00E-03, _patience: 8
Epoch: 76 | train_loss: 6.52944, val_loss: 6.96159, lr: 1.00E-03, _patience: 10
Epoch: 81 | train_loss: 6.38914, val_loss: 6.89294, lr: 1.00E-03, _patience: 10
Epoch: 86 | train_loss: 6.24970, val_loss: 6.86126, lr: 1.00E-03, _patience: 9
Epoch: 91 | train_loss: 6.10893, val_loss: 6.74132, lr: 1.00E-03, _patience: 10
100%|██████████| 100/100 [00:01<00:00, 87.60it/s]
Epoch: 96 | train_loss: 5.96172, val_loss: 6.66766, lr: 1.00E-03, _patience: 10

(NCF(
   (user_embd_mtrx): Embedding(3, 4)
   (item_embd_mtrx): Embedding(3, 4)
   (h): Linear(in_features=4, out_features=1, bias=True)
   (fst_lyr): Linear(in_features=8, out_features=4, bias=True)
   (snd_lyr): Linear(in_features=4, out_features=2, bias=True)
   (thrd_lyr): Linear(in_features=2, out_features=1, bias=True)
   (out_lyr): Linear(in_features=1, out_features=1, bias=True)
 ), 6.608727216720581)
{% endraw %}
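
`train` returns the best model snapshot together with its validation loss; `predict_step` can then score any loader with the trained weights. A minimal sketch reusing the toy `valid_loader` from above:

{% raw %}
preds = trainer.predict_step(valid_loader)
print(preds.shape)  # (4, 1): one clipped rating per row of the toy validation split
{% endraw %}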