14 changes: 14 additions & 0 deletions asset/model_list.json
@@ -2214,6 +2214,20 @@
"ref_code": "https://github.com/iesl/softmax_CPR_recommend",
"repository": "RecBole",
"repo_link": "https://github.com/RUCAIBox/RecBole"
},
{
"category": "Sequential Recommendation",
"cate_link": "/docs/user_guide/model_intro.html#sequential-recommendation",
"year": "2024",
"pub": "TOIS'24",
"model": "TriMLP",
"model_link": "/docs/user_guide/model/sequential/trimlp.html",
"paper": "TriMLP: A Foundational MLP-Like Architecture for Sequential Recommendation",
"paper_link": "https://dl.acm.org/doi/10.1145/3670995",
"authors": "Yiheng Jiang, Yuanbo Xu, Yongjian Yang, Funing Yang, Pengyang Wang, Chaozhuo Li, Fuzhen Zhuang, Hui Xiong",
"ref_code": "https://github.com/jiangyiheng1/TriMLP",
"repository": "RecBole",
"repo_link": "https://github.com/RUCAIBox/RecBole"
}
]
}
Binary file added docs/source/asset/trimlp.png
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -11,7 +11,7 @@ Introduction
RecBole is a unified, comprehensive and efficient framework developed based on PyTorch.
It aims to help the researchers to reproduce and develop recommendation models.

In the lastest release, our library includes 94 recommendation algorithms `[Model List]`_, covering four major categories:
In the latest release, our library includes 95 recommendation algorithms `[Model List]`_, covering four major categories:

- General Recommendation
- Sequential Recommendation
@@ -32,3 +32,4 @@ recbole.model.sequential\_recommender
recbole.model.sequential_recommender.fearec
recbole.model.sequential_recommender.gru4reccpr
recbole.model.sequential_recommender.sasreccpr
recbole.model.sequential_recommender.trimlp
@@ -0,0 +1,4 @@
.. automodule:: recbole.model.sequential_recommender.trimlp
:members:
:undoc-members:
:show-inheritance:
79 changes: 79 additions & 0 deletions docs/source/user_guide/model/sequential/trimlp.rst
@@ -0,0 +1,79 @@
TriMLP
===========

Introduction
---------------------

`[paper] <https://dl.acm.org/doi/10.1145/3670995>`_

**Title:** TriMLP: A Foundational MLP-Like Architecture for Sequential Recommendation

**Authors:** Yiheng Jiang, Yuanbo Xu, Yongjian Yang, Funing Yang, Pengyang Wang, Chaozhuo Li, Fuzhen Zhuang, Hui Xiong

**Abstract:** In this work, we present TriMLP as a foundational MLP-like architecture for the sequential recommendation, simultaneously achieving computational efficiency and promising performance. First, we empirically study the incompatibility between existing purely MLP-based models and sequential recommendation, that the inherent fully-connective structure endows historical user–item interactions (referred as tokens) with unrestricted communications and overlooks the essential chronological order in sequences. Then, we propose the MLP-based Triangular Mixer to establish ordered contact among tokens and excavate the primary sequential modeling capability under the standard auto-regressive training fashion. It contains (1) a global mixing layer that drops the lower-triangle neurons in MLP to block the anti-chronological connections from future tokens and (2) a local mixing layer that further disables specific upper-triangle neurons to split the sequence as multiple independent sessions. The mixer serially alternates these two layers to support fine-grained preferences modeling, where the global one focuses on the long-range dependency in the whole sequence, and the local one calls for the short-term patterns in sessions. Experimental results on 12 datasets of different scales from 4 benchmarks elucidate that TriMLP consistently attains favorable accuracy/efficiency tradeoff over all validated datasets, where the average performance boost against several state-of-the-art baselines achieves up to 14.88%, and the maximum reduction of inference time reaches 23.73%. The intriguing properties render TriMLP a strong contender to the well-established RNN-, CNN-, and Transformer-based sequential recommenders. Code is available at https://github.com/jiangyiheng1/TriMLP.

.. image:: ../../../asset/trimlp.png
:width: 500
:align: center

Running with RecBole
-------------------------

**Model Hyper-Parameters:**

- ``embedding_size (int)`` : The embedding size of items. Defaults to ``64``.
- ``act_fct (str)`` : The activation function in the feed-forward layer. Defaults to ``'None'``. Range in ``['None', 'tanh', 'sigmoid']``.
- ``num_session (int)`` : The number of sessions per sequence. Defaults to ``2``.
- ``dropout_prob (float)`` : The dropout rate. Defaults to ``0.5``.
- ``loss_type (str)`` : The type of loss function. Is fixed to ``'CE'``.



**A Running Example:**

Write the following code into a Python file, such as `run.py`:

.. code:: python

from recbole.quick_start import run_recbole

parameter_dict = {
'train_neg_sample_args': None,
}
run_recbole(model='TriMLP', dataset='ml-100k', config_dict=parameter_dict)

And then:

.. code:: bash

python run.py

Tuning Hyper Parameters
-------------------------

If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``.

.. code:: bash

learning_rate choice [0.01, 0.005, 0.001, 0.0005, 0.0001]
act_fct choice ['None', 'tanh', 'sigmoid']
dropout_prob choice [0.2, 0.5]
num_session choice [1, 2, 3, 4, 8, 16, 32]

Note that these hyper-parameter ranges are provided for reference only; we cannot guarantee that they contain the optimal settings for this model.

Then, with the source code of RecBole (which you can download from GitHub), run ``run_hyper.py`` to tune the hyper parameters:

.. code:: bash

python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test

For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.


If you want to change parameters, dataset or evaluation settings, take a look at

- :doc:`../../../user_guide/config_settings`
- :doc:`../../../user_guide/data_intro`
- :doc:`../../../user_guide/train_eval_intro`
- :doc:`../../../user_guide/usage`
1 change: 1 addition & 0 deletions docs/source/user_guide/model_intro.rst
@@ -119,6 +119,7 @@ the sequential data. The models of session-based recommendation are also include
model/sequential/fearec
model/sequential/sasreccpr
model/sequential/gru4reccpr
model/sequential/trimlp


Knowledge-based Recommendation
1 change: 1 addition & 0 deletions recbole/model/sequential_recommender/__init__.py
@@ -29,3 +29,4 @@
from recbole.model.sequential_recommender.stamp import STAMP
from recbole.model.sequential_recommender.transrec import TransRec
from recbole.model.sequential_recommender.fearec import FEARec
from recbole.model.sequential_recommender.trimlp import TriMLP
127 changes: 127 additions & 0 deletions recbole/model/sequential_recommender/trimlp.py
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# @Time : 2024/09/26 12:19
# @Author : Andreas Peintner
# @Email : anpeintner@gmail.com

r"""
TriMLP
################################################

Reference:
Jiang et al. "TriMLP: A Foundational MLP-like Architecture for Sequential Recommendation" in TOIS 2024.

Reference code:
https://github.com/jiangyiheng1/TriMLP/
"""

import torch
from torch import nn

from recbole.model.abstract_recommender import SequentialRecommender

def global_kernel(seq_len):
mask = torch.triu(torch.ones([seq_len, seq_len]))
matrix = torch.ones([seq_len, seq_len])
matrix = matrix.masked_fill(mask == 0.0, -1e9)
kernel = nn.parameter.Parameter(matrix, requires_grad=True)
return kernel


def local_kernel(seq_len, n_session):
mask = torch.zeros([seq_len, seq_len])
for i in range(0, seq_len, seq_len // n_session):
mask[i:i + seq_len // n_session, i:i + seq_len // n_session] = torch.ones(
[seq_len // n_session, seq_len // n_session])
mask = torch.triu(mask)
matrix = torch.ones([seq_len, seq_len])
matrix = matrix.masked_fill(mask == 0.0, -1e9)
kernel = nn.parameter.Parameter(matrix, requires_grad=True)
return kernel
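The masking logic above can be checked with a small standalone sketch (illustrative only, not part of the PR): for a toy ``seq_len=4`` and ``n_session=2``, the global mask keeps the full upper triangle, while the local mask restricts it to per-session blocks; masked positions receive ``-1e9`` so the softmax drives their weight to roughly zero.

```python
import torch

# Illustrative sketch (not part of the PR): reproduce the mask patterns that
# global_kernel and local_kernel build, for seq_len=4 and n_session=2.
seq_len, n_session = 4, 2

# Global mask: upper triangle (incl. diagonal) is kept, the rest is blocked.
global_mask = torch.triu(torch.ones(seq_len, seq_len))

# Local mask: upper triangle restricted to block-diagonal sessions.
local_mask = torch.zeros(seq_len, seq_len)
s = seq_len // n_session
for i in range(0, seq_len, s):
    local_mask[i:i + s, i:i + s] = 1.0
local_mask = torch.triu(local_mask)

# Blocked entries get -1e9, so softmax assigns them near-zero weight.
kernel = torch.ones(seq_len, seq_len).masked_fill(global_mask == 0.0, -1e9)
weights = kernel.softmax(dim=-1)

print(global_mask.int())
print(local_mask.int())
```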

class TriMixer(nn.Module):
def __init__(self, seq_len, n_session, act=nn.Sigmoid()):
super().__init__()
assert seq_len % n_session == 0
self.l = seq_len
self.n_s = n_session
self.act = act
self.local_mixing = local_kernel(self.l, self.n_s)
self.global_mixing = global_kernel(self.l)

def forward(self, x):
x = torch.matmul(x.permute(0, 2, 1), self.global_mixing.softmax(dim=-1))
if self.act:
x = self.act(x)

x = torch.matmul(x, self.local_mixing.softmax(dim=-1)).permute(0, 2, 1)
if self.act:
x = self.act(x)

return x

def extra_repr(self):
return f"seq_len={self.l}, n_session={self.n_s}, act={self.act}"
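A quick shape check (an illustrative sketch, independent of the PR): because the kernel multiplies along the sequence dimension, the masked-softmax mixing used in ``TriMixer.forward`` maps a ``[batch, seq_len, dim]`` tensor to the same shape.

```python
import torch

# Illustrative sketch (not part of the PR): the masked-softmax token mixing
# preserves the [batch, seq_len, dim] shape, since the kernel acts only on
# the sequence dimension.
B, L, D = 2, 8, 16
x = torch.randn(B, L, D)

mask = torch.triu(torch.ones(L, L))
kernel = torch.ones(L, L).masked_fill(mask == 0.0, -1e9)

# [B, D, L] @ [L, L] -> [B, D, L], then permute back to [B, L, D]
mixed = torch.matmul(x.permute(0, 2, 1), kernel.softmax(dim=-1)).permute(0, 2, 1)
print(mixed.shape)
```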


class TriMLP(SequentialRecommender):
r"""TriMLP: A Foundational MLP-like Architecture for Sequential Recommendation
"""

def __init__(self, config, dataset):
super(TriMLP, self).__init__(config, dataset)

# load parameters info
self.embedding_size = config["embedding_size"]
self.loss_type = config["loss_type"]

if config["act_fct"] == "sigmoid":
self.act_fct = nn.Sigmoid()
elif config["act_fct"] == "tanh":
self.act_fct = nn.Tanh()
else:
self.act_fct = None

self.dropout_prob = config["dropout_prob"]
self.num_session = config["num_session"]

# define layers and loss
self.item_embedding = nn.Embedding(
self.n_items, self.embedding_size, padding_idx=0
)
self.emb_dropout = nn.Dropout(self.dropout_prob)
self.mixer = TriMixer(self.max_seq_length, self.num_session, act=self.act_fct)
self.final_layer = nn.Linear(self.embedding_size, self.n_items)

self.loss_fct = nn.CrossEntropyLoss(ignore_index=0)

def forward(self, item_seq, item_seq_len):
item_seq_emb = self.item_embedding(item_seq)
item_seq_emb_dropout = self.emb_dropout(item_seq_emb)

mixer_output = self.mixer(item_seq_emb_dropout)
seq_output = self.gather_indexes(mixer_output, item_seq_len - 1)
seq_output = self.final_layer(seq_output)

return seq_output

def calculate_loss(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
scores = self.forward(item_seq, item_seq_len)
pos_items = interaction[self.POS_ITEM_ID]
loss = self.loss_fct(scores, pos_items)
return loss

def predict(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
test_item = interaction[self.ITEM_ID]
scores = self.forward(item_seq, item_seq_len).unsqueeze(-1)
scores = self.gather_indexes(scores, test_item).squeeze(-1)
return scores

def full_sort_predict(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
scores = self.forward(item_seq, item_seq_len)
return scores
5 changes: 5 additions & 0 deletions recbole/properties/model/TriMLP.yaml
@@ -0,0 +1,5 @@
embedding_size: 64
act_fct: None # None or sigmoid or tanh
num_session: 2
dropout_prob: 0.5
loss_type: 'CE'
7 changes: 7 additions & 0 deletions tests/model/test_model_auto.py
@@ -780,6 +780,13 @@ def test_fea_rec(self):
}
quick_test(config_dict)

def test_trimlp(self):
config_dict = {
"model": "TriMLP",
"train_neg_sample_args": None,
}
quick_test(config_dict)

# def test_gru4reckg(self):
# config_dict = {
# 'model': 'GRU4RecKG',