---
title: Base dataset
keywords: fastai
sidebar: home_sidebar
summary: "Base class for dataset module."
description: "Base class for dataset module."
nb_path: "nbs/datasets/base.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

AbstractDataset

v1

{% raw %}

class AbstractDataset[source]

AbstractDataset(args)

{% endraw %} {% raw %}
{% endraw %}

v2

{% raw %}

class AbstractDatasetv2[source]

AbstractDatasetv2(args)

{% endraw %} {% raw %}
{% endraw %}

Dataset

SessionDataset

v1

{% raw %}

class SessionDataset[source]

SessionDataset(root, min_session_length:int=None, min_item_support:int=None, eval_sec:int=86400) :: Dataset

Session data base class.

Args:
    min_session_length (int): Minimum number of items for a session to be valid.
    min_item_support (int): Minimum number of interactions for an item to be valid.
    eval_sec (int): Number of seconds at the end of the data to hold out as validation data.

References:

1. https://github.com/Ethan-Yys/GRU4REC-pytorch-master/blob/master/preprocessing.py
{% endraw %} {% raw %}
{% endraw %} {% raw %}
import os
import os.path as osp
import datetime

import numpy as np
import pandas as pd

# extract_zip is assumed to come from the notebook's setup imports (not shown here)

class YoochooseDataset(SessionDataset):
    data_id = '1UEcKC4EfgMVD2n_zBvAyp0vRNyv7ndSF'

    def __init__(self,
                 root,
                 min_session_length: int = 2,
                 min_item_support: int = 5,
                 eval_sec: int = 86400,
                 ):
        super().__init__(root, min_session_length, min_item_support, eval_sec)

    @property
    def raw_file_names(self) -> str:
        return 'rsc15-clicks.dat'

    @property
    def processed_file_names(self) -> list:
        return ['yoochoose_train.txt', 'yoochoose_valid.txt']

    def download(self):
        from google_drive_downloader import GoogleDriveDownloader as gdd
        from shutil import move, rmtree

        path = osp.join(self.raw_dir, 'rsc15.zip')
        gdd.download_file_from_google_drive(self.data_id, path)
        extract_zip(path, self.raw_dir)
        # the zip is assumed to extract to rsc15/raw/<clicks file>;
        # move the clicks file up into raw_dir under its expected name
        move(osp.join(self.raw_dir, 'rsc15', 'raw', self.raw_file_names),
             osp.join(self.raw_dir, self.raw_file_names))
        rmtree(osp.join(self.raw_dir, 'rsc15'))
        os.unlink(path)

    def process(self):
        df = self.load_ratings_df()
        if self.min_session_length is not None:
            df = self.remove_short_sessions(df)
        if self.min_item_support is not None:
            df = self.remove_sparse_items(df)
        train, test = self.split_df(df)
        train.to_csv(self.processed_paths[0], sep=',', index=False)
        test.to_csv(self.processed_paths[1], sep=',', index=False)

    def load_ratings_df(self):
        df = pd.read_csv(self.raw_paths[0], header=None, usecols=[0, 1, 2],
                         dtype={0: np.int32, 1: str, 2: np.int64})
        df.columns = ['uid', 'timestamp', 'sid']
        df['timestamp'] = df['timestamp'].apply(lambda x: datetime.datetime.strptime(
            x, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp())
        return df
{% endraw %} {% raw %}
ds = YoochooseDataset(root='/content/yoochoose')
Processing...
Training Set has 31637239 Events, 7966257 Sessions, and 37483 Items


Validation Set has 71222 Events, 15324 Sessions, and 6751 Items


Done!
{% endraw %} {% raw %}
!tree --du -h -C /content/yoochoose
/content/yoochoose
├── [995M]  processed
│   ├── [993M]  yoochoose_train.txt
│   └── [2.3M]  yoochoose_valid.txt
└── [1.4G]  raw
    └── [1.4G]  rsc15-clicks.dat

 2.4G used in 2 directories, 3 files
{% endraw %}
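The process() step above relies on three helpers inherited from SessionDataset (remove_short_sessions, remove_sparse_items, split_df) whose bodies are not shown here. The following is a minimal sketch of how they plausibly behave given the documented arguments; column names follow load_ratings_df above (uid is the session id, sid the item id). This is illustrative, not the library's actual code.

{% raw %}
import pandas as pd

def remove_short_sessions(df: pd.DataFrame, min_session_length: int) -> pd.DataFrame:
    # keep only sessions with at least min_session_length events
    session_sizes = df.groupby('uid').size()
    keep = session_sizes[session_sizes >= min_session_length].index
    return df[df['uid'].isin(keep)]

def remove_sparse_items(df: pd.DataFrame, min_item_support: int) -> pd.DataFrame:
    # keep only items that occur in at least min_item_support events
    item_sizes = df.groupby('sid').size()
    keep = item_sizes[item_sizes >= min_item_support].index
    return df[df['sid'].isin(keep)]

def split_df(df: pd.DataFrame, eval_sec: int = 86400):
    # sessions ending within the last eval_sec seconds form the validation set
    session_end = df.groupby('uid')['timestamp'].max()
    cutoff = df['timestamp'].max() - eval_sec
    valid_sessions = session_end[session_end >= cutoff].index
    train = df[~df['uid'].isin(valid_sessions)]
    valid = df[df['uid'].isin(valid_sessions)]
    return train, valid
{% endraw %}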

v2

{% raw %}

class SessionDatasetv2[source]

SessionDatasetv2(root, column_names) :: Dataset

Dataset base class

{% endraw %} {% raw %}
{% endraw %}
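The v2 base class is documented here only by its signature, so usage has to be inferred: a subclass presumably passes the raw log's column names up to the base constructor. A speculative sketch (the subclass and column names below are hypothetical, not from the library):

{% raw %}
class MySessionDataset(SessionDatasetv2):  # hypothetical subclass
    def __init__(self, root):
        # column_names: how the base class should label the raw log's columns
        # (these names are illustrative assumptions)
        super().__init__(root, column_names=['session_id', 'timestamp', 'item_id'])
{% endraw %}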

v3

{% raw %}

class SessionDatasetv3[source]

SessionDatasetv3(root) :: Dataset

Dataset base class

{% endraw %} {% raw %}
{% endraw %}

v4

{% raw %}

class SessionDatasetv4[source]

SessionDatasetv4(root, process_method, min_date=None, session_length=None, min_session_length=None, min_item_support=None, num_slices=None, days_offset=None, days_shift=None, days_train=None, days_test=None, data=None) :: Dataset

Session dataset base class.

Args:
    root (string): Root directory where the dataset should be saved.
    process_method (string): One of:
        last: last day => test set
        last_min_date: last day => test set, but only from a minimal date onwards
        days_test: last N days => test set
        slice: create multiple train-test combinations with a sliding-window approach
    min_date (string): Minimum date.
    session_length (int): Session time length (default: 30 * 60, i.e. 30 minutes).
    min_session_length (int): Minimum number of items for a session to be valid.
    min_item_support (int): Minimum number of interactions for an item to be valid.
    num_slices (int): Number of slices to create.
    days_offset (int): Offset in days from the first date in the dataset.
    days_shift (int): Number of days the training start date is shifted after creating one slice.
    days_train (int): Days in the train set of each slice.
    days_test (int): Days in the test set of each slice.

{% endraw %} {% raw %}
{% endraw %}
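As an illustration of the process_method options listed above, a days_test split takes everything before the last N days as training data and the last N days as test data. The helper below is a hedged sketch of that idea, assuming a pandas DataFrame with a Unix 'timestamp' column; it is not the class's actual implementation.

{% raw %}
import pandas as pd

def split_last_days(df: pd.DataFrame, days_test: int):
    # events in the final days_test days become the test set
    cutoff = df['timestamp'].max() - days_test * 24 * 60 * 60
    train = df[df['timestamp'] < cutoff]
    test = df[df['timestamp'] >= cutoff]
    return train, test
{% endraw %}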

Graph Dataset

{% raw %}

class GraphDataset[source]

GraphDataset(data, shuffle=False, graph=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
train_data = ([[1, 2, 3], [2, 3, 4], [1, 2, 4], [2, 3], [1]], 
              [4, 5, 5, 4, 2])

tds = GraphDataset(train_data, shuffle=False)
print(tds.generate_batch(1))
print(tds.generate_batch(2))
print(tds.inputs)

tds = GraphDataset(train_data, shuffle=True)
print(tds.generate_batch(1))
print(tds.generate_batch(2))
print(tds.inputs)
[array([0]), array([1]), array([2]), array([3]), array([4])]
[array([0, 1]), array([2, 3]), array([4])]
[[1 2 3]
 [2 3 4]
 [1 2 4]
 [2 3 0]
 [1 0 0]]
[array([0]), array([1]), array([2]), array([3]), array([4])]
[array([0, 1]), array([2, 3]), array([4])]
[[1 2 4]
 [2 3 4]
 [1 0 0]
 [2 3 0]
 [1 2 3]]
{% endraw %}
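The output above shows what GraphDataset does with the raw sessions: inputs holds the sessions right-padded with 0 to the length of the longest one (here 3), shuffle=True permutes their order, and generate_batch(batch_size) returns a list of index arrays covering all sessions, with a smaller final batch when the count does not divide evenly. A minimal sketch that reproduces this batching behaviour (not the class's actual code):

{% raw %}
import numpy as np

def pad_sessions(sessions):
    # right-pad each session with 0 up to the longest session length
    max_len = max(len(s) for s in sessions)
    return np.array([s + [0] * (max_len - len(s)) for s in sessions])

def generate_batch(n_sessions, batch_size):
    # split the index range [0, n_sessions) into consecutive batches,
    # trimming the final batch if n_sessions is not a multiple of batch_size
    n_batch = int(np.ceil(n_sessions / batch_size))
    slices = np.split(np.arange(n_batch * batch_size), n_batch)
    slices[-1] = slices[-1][slices[-1] < n_sessions]
    return slices

print(generate_batch(5, 2))  # [array([0, 1]), array([2, 3]), array([4])]
{% endraw %}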

SessionGraph Dataset

{% raw %}

class SessionGraphDataset[source]

SessionGraphDataset(root, shuffle=False, n_node=None) :: Dataset

References

1. COTREC session-based recommender model training. https://t.ly/cXTH.
{% endraw %} {% raw %}
{% endraw %} {% raw %}
class DigineticaDataset(SessionGraphDataset):
    train_url = "https://github.com/RecoHut-Datasets/diginetica/raw/v2/train.txt"
    test_url = "https://github.com/RecoHut-Datasets/diginetica/raw/v2/test.txt"
    all_train_seq_url = "https://github.com/RecoHut-Datasets/diginetica/raw/v2/all_train_seq.txt"

    def __init__(self, root, shuffle=False, n_node=43097, is_train=True):
        self.n_node = n_node
        self.shuffle = shuffle
        self.is_train = is_train
        super().__init__(root, shuffle, n_node)

    @property
    def raw_file_names(self) -> list:
        if self.is_train:
            return ['train.txt', 'all_train_seq.txt']
        return ['test.txt', 'all_train_seq.txt']

    def download(self):
        download_url(self.all_train_seq_url, self.raw_dir)
        if self.is_train:
            download_url(self.train_url, self.raw_dir)
        else:
            download_url(self.test_url, self.raw_dir)
{% endraw %} {% raw %}
root = '/content/diginetica'

train_data = DigineticaDataset(root=root, shuffle=True, is_train=True)
test_data = DigineticaDataset(root=root, shuffle=False, is_train=False)
{% endraw %} {% raw %}
class TmallDataset(SessionGraphDataset):
    train_url = "https://github.com/RecoHut-Datasets/tmall/raw/v1/train.txt"
    test_url = "https://github.com/RecoHut-Datasets/tmall/raw/v1/test.txt"
    all_train_seq_url = "https://github.com/RecoHut-Datasets/tmall/raw/v1/all_train_seq.txt"

    def __init__(self, root, shuffle=False, n_node=40727, is_train=True):
        self.n_node = n_node
        self.shuffle = shuffle
        self.is_train = is_train
        super().__init__(root, shuffle, n_node)

    @property
    def raw_file_names(self) -> list:
        if self.is_train:
            return ['train.txt', 'all_train_seq.txt']
        return ['test.txt', 'all_train_seq.txt']

    def download(self):
        download_url(self.all_train_seq_url, self.raw_dir)
        if self.is_train:
            download_url(self.train_url, self.raw_dir)
        else:
            download_url(self.test_url, self.raw_dir)
{% endraw %} {% raw %}
root = '/content/tmall'

train_data = TmallDataset(root=root, shuffle=True, is_train=True)
test_data = TmallDataset(root=root, shuffle=False, is_train=False)
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
{% endraw %} {% raw %}
class RetailRocketDataset(SessionGraphDataset):
    train_url = "https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/train.txt"
    test_url = "https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/test.txt"
    all_train_seq_url = "https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/all_train_seq.txt"

    def __init__(self, root, shuffle=False, n_node=40727, is_train=True):
        self.n_node = n_node
        self.shuffle = shuffle
        self.is_train = is_train
        super().__init__(root, shuffle, n_node)

    @property
    def raw_file_names(self) -> list:
        if self.is_train:
            return ['train.txt', 'all_train_seq.txt']
        return ['test.txt', 'all_train_seq.txt']

    def download(self):
        download_url(self.all_train_seq_url, self.raw_dir)
        if self.is_train:
            download_url(self.train_url, self.raw_dir)
        else:
            download_url(self.test_url, self.raw_dir)
{% endraw %} {% raw %}
root = '/content/retail_rocket'

train_data = RetailRocketDataset(root=root, shuffle=True, is_train=True)
test_data = RetailRocketDataset(root=root, shuffle=False, is_train=False)
Downloading https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/all_train_seq.txt
Downloading https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/train.txt
Using existing file all_train_seq.txt
Downloading https://github.com/RecoHut-Datasets/retail_rocket/raw/v1/test.txt
{% endraw %} {% raw %}
class SampleDataset(SessionGraphDataset):
    train_url = "https://github.com/RecoHut-Datasets/sample_session/raw/v2/train.txt"
    test_url = "https://github.com/RecoHut-Datasets/sample_session/raw/v2/test.txt"
    all_train_seq_url = "https://github.com/RecoHut-Datasets/sample_session/raw/v2/all_train_seq.txt"

    def __init__(self, root, shuffle=False, n_node=309, is_train=True):
        self.n_node = n_node
        self.shuffle = shuffle
        self.is_train = is_train
        super().__init__(root, shuffle, n_node)

    @property
    def raw_file_names(self) -> list:
        if self.is_train:
            return ['train.txt', 'all_train_seq.txt']
        return ['test.txt', 'all_train_seq.txt']

    def download(self):
        download_url(self.all_train_seq_url, self.raw_dir)
        if self.is_train:
            download_url(self.train_url, self.raw_dir)
        else:
            download_url(self.test_url, self.raw_dir)
{% endraw %} {% raw %}
root = '/content/sample'

train_data = SampleDataset(root=root, shuffle=True, is_train=True)
test_data = SampleDataset(root=root, shuffle=False, is_train=False)
Downloading https://github.com/RecoHut-Datasets/sample_session/raw/v2/all_train_seq.txt
Downloading https://github.com/RecoHut-Datasets/sample_session/raw/v2/train.txt
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
Using existing file all_train_seq.txt
Downloading https://github.com/RecoHut-Datasets/sample_session/raw/v2/test.txt
/usr/local/lib/python3.7/dist-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order)
{% endraw %}
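The VisibleDeprecationWarning printed during Tmall and sample processing comes from NumPy being handed ragged session lists (sessions of different lengths) without an explicit dtype. If you hit the same warning in your own preprocessing, passing dtype=object is the fix the warning itself suggests:

{% raw %}
import numpy as np

sessions = [[1, 2, 3], [2, 3], [1]]
arr = np.asarray(sessions, dtype=object)  # explicit dtype silences the warning
{% endraw %}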

Rating Dataset

{% raw %}

class RatingDataset[source]

RatingDataset(root, min_uc, min_sc, split='leave_one_out', dataset_split_seed=42, eval_set_size=None, min_rating=None, iterative_triplet=False) :: Dataset

Interaction data with rating feedback

Args:
    root: data folder path
    min_uc: minimum user count to keep in the data
    min_sc: minimum item count to keep in the data
    split: data split method - leave_one_out/holdout
    min_rating: minimum rating threshold to convert explicit feedback into implicit

References:

1. https://github.com/Yueeeeeeee/RecSys-Extraction-Attack/tree/main/datasets
{% endraw %} {% raw %}
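The Processing log shown for the subclasses below (Filtering triplets, Densifying index, Splitting) reflects the three stages the base class runs over the ratings DataFrame: drop users and items under the min_uc/min_sc thresholds, remap the surviving ids to a dense contiguous range, and split per user. A rough sketch of those stages, under the assumption of a leave_one_out split on timestamps (illustrative, not the library's code):

import pandas as pd

def filter_triplets(df: pd.DataFrame, min_uc: int, min_sc: int) -> pd.DataFrame:
    # drop items with fewer than min_sc interactions, then users with fewer than min_uc
    good_items = df['sid'].value_counts()
    df = df[df['sid'].isin(good_items[good_items >= min_sc].index)]
    good_users = df['uid'].value_counts()
    return df[df['uid'].isin(good_users[good_users >= min_uc].index)]

def densify_index(df: pd.DataFrame) -> pd.DataFrame:
    # remap the surviving raw ids to dense contiguous integers
    df = df.copy()
    df['uid'] = df['uid'].map({u: i for i, u in enumerate(df['uid'].unique())})
    df['sid'] = df['sid'].map({s: i for i, s in enumerate(df['sid'].unique())})
    return df

def leave_one_out(df: pd.DataFrame):
    # per user: last interaction -> test, second-to-last -> val, the rest -> train
    df = df.sort_values('timestamp')
    test = df.groupby('uid').tail(1)
    rest = df.drop(test.index)
    val = rest.groupby('uid').tail(1)
    train = rest.drop(val.index)
    return train, val, test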
{% endraw %} {% raw %}
class AmazonGamesDataset(RatingDataset):
    url = "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Video_Games.csv"

    @property
    def raw_file_names(self):
        return 'ratings_Video_Games.csv'

    def download(self):
        download_url(self.url, self.raw_dir)

    def load_ratings_df(self):
        df = pd.read_csv(self.raw_paths[0], header=None)
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
{% endraw %} {% raw %}
ds = AmazonGamesDataset(root='/content/amazon_games', min_uc=10, min_sc=5)
Processing...
Filtering triplets
Densifying index
Splitting
100%|██████████| 7519/7519 [00:02<00:00, 2512.75it/s]
Done!
{% endraw %} {% raw %}
class AmazonBeautyDataset(RatingDataset):
    url = "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv"

    @property
    def raw_file_names(self):
        return 'ratings_Beauty.csv'

    def download(self):
        download_url(self.url, self.raw_dir)

    def load_ratings_df(self):
        df = pd.read_csv(self.raw_paths[0], header=None)
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
{% endraw %} {% raw %}
ds = AmazonBeautyDataset(root='/content/amazon_beauty', min_uc=10, min_sc=5)
Downloading http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Beauty.csv
Processing...
Filtering triplets
Densifying index
Splitting
100%|██████████| 7519/7519 [00:02<00:00, 2527.69it/s]
Done!
{% endraw %} {% raw %}
class ML1mDataset(RatingDataset):
    url = "http://files.grouplens.org/datasets/movielens/ml-1m.zip"

    @property
    def raw_file_names(self):
        return 'ratings.dat'

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        from shutil import move, rmtree
        move(osp.join(self.raw_dir, 'ml-1m', self.raw_file_names), self.raw_dir)
        rmtree(osp.join(self.raw_dir, 'ml-1m'))
        os.unlink(path)

    def load_ratings_df(self):
        df = pd.read_csv(self.raw_paths[0], sep='::', header=None, engine='python')
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
{% endraw %} {% raw %}
ds = ML1mDataset(root='/content/ML1m', min_uc=10, min_sc=5)
Processing...
Filtering triplets
Densifying index
Splitting
100%|██████████| 6040/6040 [00:02<00:00, 2590.97it/s]
Done!
{% endraw %} {% raw %}
class SteamGamesDataset(RatingDataset):
    url = "http://cseweb.ucsd.edu/~wckang/steam_reviews.json.gz"

    @property
    def raw_file_names(self):
        return 'steam_reviews.json'

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_gz(path, self.raw_dir)
        os.unlink(path)

    def load_ratings_df(self):
        import ast
        data = []
        # each line of the dump is a Python dict literal (single-quoted keys),
        # so ast.literal_eval is used rather than json.loads
        with open(self.raw_paths[0], 'r', encoding='utf-8') as f:
            for line in f:
                temp = ast.literal_eval(line)
                data.append([temp['username'], temp['product_id'], temp['date']])

        return pd.DataFrame(data, columns=['uid', 'sid', 'timestamp'])
{% endraw %} {% raw %}
ds = SteamGamesDataset(root='/content/steam', min_uc=10, min_sc=5)
Processing...
Filtering triplets
Densifying index
Splitting
100%|██████████| 120145/120145 [01:10<00:00, 1709.62it/s]
Done!
{% endraw %} {% raw %}
class YoochooseDataset(RatingDataset):
    url = "https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z"

    @property
    def raw_file_names(self):
        return 'yoochoose-clicks.dat'

    def download(self):
        path = download_url(self.url, self.raw_dir)
        # pip install pyunpack patool
        import pyunpack
        pyunpack.Archive(path).extractall(self.raw_dir)
        os.unlink(path)

    def load_ratings_df(self):
        df = pd.read_csv(self.raw_paths[0], header=None)
        df.columns = ['uid', 'timestamp', 'sid', 'category']
        return df
{% endraw %} {% raw %}
ds = YoochooseDataset(root='/content/yoochoose', min_uc=10, min_sc=5)
Processing...
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:28: DtypeWarning: Columns (3) have mixed types.Specify dtype option on import or set low_memory=False.
Filtering triplets
Densifying index
Splitting
100%|██████████| 449961/449961 [03:55<00:00, 1913.52it/s]
Done!
{% endraw %}