Session-aware BERT4Rec
Official repository for "Exploiting Session Information in BERT-based Session-aware Sequential Recommendation" (SIGIR 2022 short paper).
Everything in the paper is implemented (including vanilla BERT4Rec and SASRec), and can be reproduced.
Usage
1. Build Docker
./scripts/build.sh
2. Download dataset
Download the corresponding datasets into a directory of your choice, such as ./roughs.
For Steam dataset, use version 2.
Rename datasets: 'ml1m' for MovieLens-1M, 'ml20m' for MovieLens-20M, 'steam2' for Steam.
3. Preprocess
--rough_root: for original dataset files
--data_root: for processed data files
python preprocess.py prepare ml1m --data_root ./data --rough_root ./roughs
python preprocess.py prepare ml20m --data_root ./data --rough_root ./roughs
python preprocess.py prepare steam2 --data_root ./data --rough_root ./roughs
For some stats:
python preprocess.py count stats --data_root ./data --rough_root ./roughs > dstats.tsv
4. Run
See entry.py for the default configuration settings.
To override them, create a directory under runs/, such as ./runs/ml1m/bert4rec/vanilla/, and put a config.json in it.
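As a rough illustration of the step above, the run directory and its config.json can also be created from Python. This is only a sketch: the helper name write_run_config and its runs_root parameter are hypothetical and not part of this repository, and the overrides dictionary is left empty because the valid keys are whatever entry.py defines. The demo writes into a temporary directory so that it does not touch a real runs/ tree.

```python
import json
import tempfile
from pathlib import Path


def write_run_config(run_name, overrides, runs_root="runs"):
    """Create <runs_root>/<run_name>/config.json holding the given overrides.

    Hypothetical helper for illustration; not part of the repository.
    """
    run_dir = Path(runs_root) / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = run_dir / "config.json"
    config_path.write_text(json.dumps(overrides, indent=2))
    return config_path


# Demo in a temporary directory; pass runs_root="runs" for a real run.
# The overrides dict is empty here: take actual keys from entry.py.
with tempfile.TemporaryDirectory() as tmp:
    path = write_run_config("ml1m/bert4rec/vanilla", {}, runs_root=tmp)
    demo_relative = path.relative_to(tmp).as_posix()
    demo_content = json.loads(path.read_text())

print(demo_relative)
```

The run name passed here is the same string later given to entry.py, for example ml1m/bert4rec/vanilla.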
Sample Run Script
My x0.sh file that uses GPU No. 0:
runpy () {
docker run \
-it \
--rm \
--init \
--gpus '"device=0"' \
--shm-size 16G \
--volume="$HOME/.cache/torch:/root/.cache/torch" \
--volume="$PWD:/workspace" \
session-aware-bert4rec \
python "$@"
}
runpy entry.py ml1m/bert4rec/vanilla
Terminology
The df_ prefix always denotes a pandas DataFrame.
uid (str|int): User ID (unique).
iid (str|int): Item ID (unique).
sid (str|int): Session ID (unique), used only for session separation.
uindex (int): mapped index number of User ID, 1 ~ n.
iindex (int): mapped index number of Item ID, 1 ~ m.
timestamp (int): UNIX timestamp.
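To make the relationship between raw IDs and mapped indices concrete, here is a minimal sketch of a 1-based index mapping. The function build_index is hypothetical, and assigning indices in first-seen order is an assumption; the actual preprocessing in this repository may order them differently.

```python
def build_index(ids):
    """Map each distinct raw ID to a 1-based index, in first-seen order.

    Hypothetical helper; the ordering rule is an assumption.
    """
    mapping = {}
    for raw_id in ids:
        if raw_id not in mapping:
            mapping[raw_id] = len(mapping) + 1
    return mapping


# Made-up user IDs; duplicates map to the same uindex.
uids = ["u42", "u7", "u42", "u99"]
uid2uindex = build_index(uids)
print(uid2uindex)  # {'u42': 1, 'u7': 2, 'u99': 3}
```

Starting at 1 rather than 0 matches the "1 ~ n" range stated above.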
Data Files
After preprocessing, each data/:dataset_name/ directory contains the following files.
uid2uindex.pkl (dict): {uid → uindex}.
iid2iindex.pkl (dict): {iid → iindex}.
df_rows.pkl (df): columns (uindex, iindex, sid, timestamp), with no index.
train.pkl (dict): {uindex → [list of (iindex, sid, timestamp)]}.
valid.pkl (dict): {uindex → [list of (iindex, sid, timestamp)]}.
test.pkl (dict): {uindex → [list of (iindex, sid, timestamp)]}.
ns_random.pkl (dict): {uindex → [list of iindex]}.
ns_popular.pkl (dict): {uindex → [list of iindex]}.
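The sketch below shows one way the {uindex → [list of (iindex, sid, timestamp)]} structure could be read and split into sessions using sid. It is an illustration, not the repository's own loader: split_sessions is a hypothetical helper, the sample data is made up, and it assumes that the rows of one session are contiguous once sorted by timestamp. A pickle round trip in memory stands in for reading train.pkl.

```python
import pickle
from itertools import groupby


def split_sessions(interactions):
    """Group one user's (iindex, sid, timestamp) rows into per-session item lists.

    Hypothetical helper; assumes rows of a session are contiguous in time.
    """
    ordered = sorted(interactions, key=lambda row: row[2])  # sort by timestamp
    return [
        [iindex for iindex, _sid, _ts in rows]
        for _sid_key, rows in groupby(ordered, key=lambda row: row[1])
    ]


# Made-up sample with the same shape as train.pkl.
sample = {
    1: [(10, "s1", 100), (11, "s1", 160), (12, "s2", 5000)],
    2: [(10, "s3", 200)],
}

# Stand-in for: with open("data/ml1m/train.pkl", "rb") as f: train = pickle.load(f)
train = pickle.loads(pickle.dumps(sample))

sessions = {uindex: split_sessions(rows) for uindex, rows in train.items()}
print(sessions)  # {1: [[10, 11], [12]], 2: [[10]]}
```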