---
title: Session (torch) Datasets
keywords: fastai
sidebar: home_sidebar
summary: "Session-based recommendation datasets in PyTorch Dataset format."
description: "Session-based recommendation datasets in PyTorch Dataset format."
nb_path: "nbs/datasets/torch/session.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Dataset[source]

Dataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class YoochooseDataset[source]

YoochooseDataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`. *(This docstring is inherited from the base `Dataset` class; `YoochooseDataset` itself is instantiated directly in the usage example below.)*

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build the Yoochoose session dataset; `maxlen` caps the sequence length.
dataset = YoochooseDataset(root='/content/yoochoose', maxlen=30)

# This is a DataLoader (it batches dataset items), not a torch Sampler.
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)
# Fetch only the first batch for inspection. NOTE(review): the output below is a
# pair of tensors where the second looks like the first shifted by one item,
# presumably (input, target) sequences — confirm against YoochooseDataset.__getitem__.
samples = next(iter(loader))
samples
Using existing file yoochoose.csv
Avg length: 10.0
Maximum length: 10
Number of sessions: 80183
Number of items: 12936
Number of actions: 406979
Average length of sessions: 5.075627003230111
[tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0, 10309, 10309, 10309],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,   794,  5005,  6891],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0, 10631,  4104,  9852],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,  9469,  9486,  9469],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,   155,  8790,  6931],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,   770, 11239,  6040],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0, 11641, 11610, 12033],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,  8449,  5705, 10331,   170,  8485]]),
 tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0, 10309, 10309, 10309, 10309],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,   794,  5005,  6891,  6501],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0, 10631,  4104,  9852, 10007],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,  9469,  9486,  9469,  9486],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,   155,  8790,  6931,  8821],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,   770, 11239,  6040, 11235],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0, 11641, 11610, 12033, 11638],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,  8449,  5705, 10331,   170,  8485, 10332]])]
{% endraw %} {% raw %}

class NowplayingDataset[source]

NowplayingDataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`. *(This docstring is inherited from the base `Dataset` class; `NowplayingDataset` itself is instantiated directly in the usage example below.)*

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build the Nowplaying session dataset; `maxlen` caps the sequence length.
dataset = NowplayingDataset(root='/content/nowplaying', maxlen=30)

# This is a DataLoader (it batches dataset items), not a torch Sampler.
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)
# Fetch only the first batch for inspection. NOTE(review): the output below is a
# pair of tensors where the second looks like the first shifted by one item,
# presumably (input, target) sequences — confirm against NowplayingDataset.__getitem__.
samples = next(iter(loader))
samples
Downloading https://github.com/RecoHut-Datasets/nowplaying/raw/v3/nowplaying.csv
Avg length: 20.0
Maximum length: 20
Number of sessions: 113918
Number of items: 239221
Number of actions: 1184815
Average length of sessions: 10.400595164943205
[tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0, 114002, 113983,  89621, 113960, 113884, 113926, 114000, 113738,
          113930,   3168, 113805, 113800, 113789, 113872, 114018, 113881, 113869,
          113776,  21568],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,  13653,  11910,  28131,
            4896,  33231],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0, 217911, 218397,  23439,  23684,  40048,  23439,  22123, 218298,
           58345, 218399],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,  16292,   3786,  45272,   3574,  28015,  16926,
           27992,  33024],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
           76624,  76624, 113070, 113070,  76624,  76624, 113070, 113070,  76624,
           76624, 113070],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0, 127696,  39067,  82151,  29706,  29201,  29605,   4791,  29298,
          127939,  29456,  29779, 109896,   5945,  73638, 127962,  44011,  29721,
          127625, 114913],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,  32544, 155346, 155508, 155074, 155886, 155223, 155360, 155356,
          154929, 154914, 155887, 154877, 155115, 155888,  75852, 154969, 155889,
          155291, 155890],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,  56331,  56332,
           56333,  56334,  56335,  13547,  56336,   7363,  56337,  56338,  56339,
            4735,   9554]]),
 tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
          114002, 113983,  89621, 113960, 113884, 113926, 114000, 113738, 113930,
            3168, 113805, 113800, 113789, 113872, 114018, 113881, 113869, 113776,
           21568, 113925],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,  13653,  11910,  28131,   4896,
           33231,   8409],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
          217911, 218397,  23439,  23684,  40048,  23439,  22123, 218298,  58345,
          218399, 218330],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,  16292,   3786,  45272,   3574,  28015,  16926,  27992,
           33024,  24359],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,  76624,
           76624, 113070, 113070,  76624,  76624, 113070, 113070,  76624,  76624,
          113070, 113070],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
          127696,  39067,  82151,  29706,  29201,  29605,   4791,  29298, 127939,
           29456,  29779, 109896,   5945,  73638, 127962,  44011,  29721, 127625,
          114913, 127608],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
           32544, 155346, 155508, 155074, 155886, 155223, 155360, 155356, 154929,
          154914, 155887, 154877, 155115, 155888,  75852, 154969, 155889, 155291,
          155890, 155891],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,  56331,  56332,  56333,
           56334,  56335,  13547,  56336,   7363,  56337,  56338,  56339,   4735,
            9554,  56340]])]
{% endraw %} {% raw %}

class DigineticaDataset[source]

DigineticaDataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`. *(This docstring is inherited from the base `Dataset` class; `DigineticaDataset` itself is instantiated directly in the usage example below.)*

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build the Diginetica session dataset; `maxlen` caps the sequence length.
dataset = DigineticaDataset(root='/content/diginetica', maxlen=30)

# This is a DataLoader (it batches dataset items), not a torch Sampler.
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)
# Fetch only the first batch for inspection. NOTE(review): the output below is a
# pair of tensors where the second looks like the first shifted by one item,
# presumably (input, target) sequences — confirm against DigineticaDataset.__getitem__.
samples = next(iter(loader))
samples
Downloading https://github.com/RecoHut-Datasets/diginetica/raw/v4/diginetica.csv
Avg length: 8.777109003245833
Maximum length: 70
Number of sessions: 63466
Number of items: 38970
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
Number of actions: 557048
Average length of sessions: 8.777109003245833
[tensor([[    0,     0,     0,     0,     0,     0,     0,     0,  2387,  2245,
           9141,  2366,  9142,  9143,  3193,  3193,  1726,  1725,  2366,  1722,
           2366,  2366,  9144,  3197,  9145,  1722,  9146,  9147,  9146],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0, 17095, 17101, 17094, 17100, 17096],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0, 13816,
          10789,  9204, 11198, 23151,  8289, 30676,  3372, 30678, 14125],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,  4814, 10013, 18788,  9285, 14081],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0, 13064, 36257, 30911, 11052],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0, 28073, 19585,
          17214, 13842, 28815, 28815, 13842, 17214, 19585, 16278, 15659],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0, 16291, 17881,  5630, 20969, 20969, 20829, 12938],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,  7298, 25445,  1397, 25445,  7298,  3191,  1735]]),
 tensor([[    0,     0,     0,     0,     0,     0,     0,  2387,  2245,  9141,
           2366,  9142,  9143,  3193,  3193,  1726,  1725,  2366,  1722,  2366,
           2366,  9144,  3197,  9145,  1722,  9146,  9147,  9146,  9148],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0, 17095, 17101, 17094, 17100, 17096, 17102],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0, 13816, 10789,
           9204, 11198, 23151,  8289, 30676,  3372, 30678, 14125, 14124],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,  4814, 10013, 18788,  9285, 14081,   336],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0, 13064, 36257, 30911, 11052,  7098],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0, 28073, 19585, 17214,
          13842, 28815, 28815, 13842, 17214, 19585, 16278, 15659, 28072],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0, 16291, 17881,  5630, 20969, 20969, 20829, 12938, 12938],
         [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,  7298, 25445,  1397, 25445,  7298,  3191,  1735, 25445]])]
{% endraw %} {% raw %}

class LastfmDataset[source]

LastfmDataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`. *(This docstring is inherited from the base `Dataset` class; `LastfmDataset` itself is instantiated directly in the usage example below.)*

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override `__getitem__`, supporting fetching a data sample for a given key. Subclasses can also optionally override `__len__`, which many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader` expect to return the size of the dataset.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build the Last.fm session dataset; `maxlen` caps the sequence length.
dataset = LastfmDataset(root='/content/lastfm', maxlen=30)

# This is a DataLoader (it batches dataset items), not a torch Sampler.
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)
# Fetch only the first batch for inspection. NOTE(review): the output below is a
# pair of tensors where the second looks like the first shifted by one item,
# presumably (input, target) sequences — confirm against LastfmDataset.__getitem__.
samples = next(iter(loader))
samples
Downloading https://github.com/RecoHut-Datasets/lastfm/raw/v2/last_fm.csv
Avg length: 17.447849599510228
Maximum length: 49
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
Number of sessions: 196010
Number of items: 107391
Number of actions: 3419953
Average length of sessions: 17.447849599510228
[tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,   3821,     96,   3821,   1600,     96,
            3366,   3821,     96,   3366,  18639,   3821,   3280,   3366,     96,
            3366,   1600],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,  13864,
           35393,  13864,  50765,  13743,  51628,  34165,  44702,  62996,   9504,
          106404,  13864],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,   3363,   1875,
            2782,   1875],
         [   973,   2740,   2712,  17892,   2228,    829,   2740,    128,    744,
            1193,   1284,   2755,   1443,   4028,   2712,   7635,    620,   1861,
            3978,    790,    916,   1455,    227,  10492,   1257,    633,  29590,
            2631,   8419],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,  13576,   9525,   9513,
            9857,   9512,   9180,   9335,   1953,   9504,   6466,  19980,  19966,
           71404,  90684],
         [  8405,   5743,   8106,  35491,    675,   3393,   8239,   3469,   8239,
             675,   8405,   8239,   2895,  19781,  16754,   3572,  16754,  16088,
             637,   3713,   8171,    675,  15850,   3135,   3713,   8509,   1903,
            1900,  13521],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,   8771,   8510,   1409,   3180,   8407,   5464,
             479,    707],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,   3366,    987,    988,    987,  47629,    707,   1644,
            3775,   1291]]),
 tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,   3821,     96,   3821,   1600,     96,   3366,
            3821,     96,   3366,  18639,   3821,   3280,   3366,     96,   3366,
            1600,   3366],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,  13864,  35393,
           13864,  50765,  13743,  51628,  34165,  44702,  62996,   9504, 106404,
           13864,  20385],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,   3363,   1875,   2782,
            1875,   9683],
         [  2740,   2712,  17892,   2228,    829,   2740,    128,    744,   1193,
            1284,   2755,   1443,   4028,   2712,   7635,    620,   1861,   3978,
             790,    916,   1455,    227,  10492,   1257,    633,  29590,   2631,
            8419,   8555],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,  13576,   9525,   9513,   9857,
            9512,   9180,   9335,   1953,   9504,   6466,  19980,  19966,  71404,
           90684,  71404],
         [  5743,   8106,  35491,    675,   3393,   8239,   3469,   8239,    675,
            8405,   8239,   2895,  19781,  16754,   3572,  16754,  16088,    637,
            3713,   8171,    675,  15850,   3135,   3713,   8509,   1903,   1900,
           13521,   8509],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,   8771,   8510,   1409,   3180,   8407,   5464,    479,
             707,    937],
         [     0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,   3366,    987,    988,    987,  47629,    707,   1644,   3775,
            1291,   3775]])]
{% endraw %}