WRENCH: Weak supeRvision bENCHmark

Overview

🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications:

If you find this repository helpful, feel free to cite our publication:

@misc{zhang2021wrench,
      title={WRENCH: A Comprehensive Benchmark for Weak Supervision}, 
      author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
      year={2021},
      eprint={2109.11377},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

🔧 What is weak supervision?

Weak supervision is a paradigm for creating training labels automatically from noisy sources (e.g. heuristic rules) instead of manual annotation.
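The following sketch makes the paradigm concrete on a toy spam task. It is plain Python/NumPy rather than the WRENCH API: the three keyword labeling functions and the -1 "abstain" encoding are illustrative assumptions, and majority vote is used as the simplest way to aggregate their noisy votes into training labels.

```python
import numpy as np

ABSTAIN = -1  # convention assumed for this sketch: -1 means "no vote"
HAM, SPAM = 0, 1

# Hypothetical keyword-based labeling functions (LFs)
def lf_contains_link(text: str) -> int:
    return SPAM if "http" in text else ABSTAIN

def lf_contains_subscribe(text: str) -> int:
    return SPAM if "subscribe" in text.lower() else ABSTAIN

def lf_short_comment(text: str) -> int:
    return HAM if len(text.split()) < 5 else ABSTAIN

LFS = [lf_contains_link, lf_contains_subscribe, lf_short_comment]

def apply_lfs(texts):
    """Build the (n_examples, n_lfs) weak label matrix."""
    return np.array([[lf(t) for lf in LFS] for t in texts], dtype=int)

def majority_vote(L: np.ndarray, n_class: int) -> np.ndarray:
    """Aggregate weak labels per row; ties and all-abstain rows stay ABSTAIN."""
    preds = np.full(L.shape[0], ABSTAIN, dtype=int)
    for i, row in enumerate(L):
        votes = np.bincount(row[row != ABSTAIN], minlength=n_class)
        if votes.sum() == 0:
            continue  # no LF fired on this example
        winners = np.flatnonzero(votes == votes.max())
        if len(winners) == 1:
            preds[i] = winners[0]
    return preds

texts = [
    "Please subscribe to my channel http://example.com",
    "great song",
    "I have listened to this song every day this month",
]
L = apply_lfs(texts)
print(L)
print(majority_vote(L, n_class=2))
```

Rows where every labeling function abstains (or where the vote is tied) stay unlabeled in this sketch; label models such as Dawid-Skene or Data Programming replace the naive vote with a learned estimate of each source's accuracy.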

For a brief overview, please check out this blog.

To track recent advances in weak supervision, please follow this repo.

🔧 Installation

[1] Install Anaconda; instructions are here: https://www.anaconda.com/download/

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create the virtual environment:

conda env create -f environment.yml
conda activate wrench

If this does not work, or if you want to use only a subset of Wrench's modules, check out this wiki page.

🔧 Available Datasets

The datasets can be downloaded via this link

or via the command line:

pip install gdown 
gdown https://drive.google.com/uc?id=19wMFmpoo_0ORhBzB6n16B1nRRX508AnJ
unzip datasets.zip
rm datasets.zip 

Documentation of the dataset format and usage can be found in this wiki page.
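As a rough illustration, the sketch below reads one split into a weak label matrix. It assumes each split file (e.g. `train.json`) is a JSON object mapping example ids to records with `label`, `weak_labels` (one entry per labeling function, -1 for abstain) and `data` fields; treat this layout as an assumption and consult the wiki page for the authoritative format. `load_split` and `coverage` are hypothetical helpers, not part of WRENCH.

```python
import json

import numpy as np

def load_split(path: str):
    """Read one split file into (weak label matrix, gold labels, raw data).

    Assumed layout (hypothetical, check the wiki page):
    {"0": {"label": 1, "weak_labels": [-1, 1, ...], "data": {"text": "..."}}, ...}
    """
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    ids = sorted(records, key=int)  # stable row order by integer id
    L = np.array([records[i]["weak_labels"] for i in ids], dtype=int)
    y = np.array([records[i]["label"] for i in ids], dtype=int)
    data = [records[i]["data"] for i in ids]
    return L, y, data

def coverage(L: np.ndarray) -> float:
    """Fraction of examples with at least one non-abstain (-1) vote."""
    return float(np.mean((L != -1).any(axis=1)))
```

If the assumed layout holds, `L, y, data = load_split('../datasets/youtube/train.json')` would give a matrix with one column per labeling function (10 for Youtube, per the table below).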

classification:

| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|---|
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |

sequence tagging:

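| Name | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|
| CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
| WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
| BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
| NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
| Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
| MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
| MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |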

The detailed documentation is coming soon.

πŸ”§ Available Models

classification:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Majority Voting | Label Model | -- | link |
| Weighted Majority Voting | Label Model | -- | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| Logistic Regression | End Model | -- | link |
| MLP | End Model | -- | link |
| BERT | End Model | link | link |
| COSINE | End Model | link | link |
| Denoise | Joint Model | link | link |

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| BERT-CRF | End Model | link | link |
| LSTM-ConNet | Joint Model | link | link |
| BERT-ConNet | Joint Model | link | link |

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts label models designed for classification to sequence tagging tasks.
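One simple way such an adaptation can work is sketched here under our own assumptions, not from SeqLabelModelWrapper's source: treat every token as an independent classification instance, stack the token-level weak labels of all sentences into one matrix, run a classification label model on it, and split the predictions back into per-sentence tag sequences. `wrap_for_sequences` and `majority_vote_tokens` are hypothetical names, and tag id 0 is assumed to mean "O" (outside any entity).

```python
from functools import partial
from typing import Callable, List

import numpy as np

ABSTAIN = -1

def majority_vote_tokens(L: np.ndarray, n_class: int) -> np.ndarray:
    """Toy classification label model applied to each token (row)."""
    preds = np.zeros(L.shape[0], dtype=int)  # tag 0 ("O") when every LF abstains
    for i, row in enumerate(L):
        votes = np.bincount(row[row != ABSTAIN], minlength=n_class)
        if votes.sum() > 0:
            preds[i] = int(np.argmax(votes))  # ties resolve to the lowest tag id
    return preds

def wrap_for_sequences(label_model: Callable[[np.ndarray], np.ndarray],
                       sentence_weak_labels: List[np.ndarray]) -> List[np.ndarray]:
    """Apply a token-level label model to sentences of shape (n_tokens, n_lfs)."""
    lengths = [len(s) for s in sentence_weak_labels]
    flat = np.concatenate(sentence_weak_labels, axis=0)  # (total_tokens, n_lfs)
    flat_preds = label_model(flat)
    # Cut the flat predictions back into one tag sequence per sentence
    return np.split(flat_preds, np.cumsum(lengths)[:-1])

# Two sentences, two LFs; hypothetical tag ids: 0 = O, 1 = PER, 2 = LOC
sent_a = np.array([[1, 1], [-1, -1], [2, -1]])
sent_b = np.array([[-1, 2], [1, 2]])
tags = wrap_for_sequences(partial(majority_vote_tokens, n_class=3), [sent_a, sent_b])
print([t.tolist() for t in tags])
```

Because each token is labeled independently, this adaptation ignores dependencies between neighboring tags (e.g. BIO consistency), which sequence tagging label models such as the Hidden Markov Model listed above capture through transitions between tags.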

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for _ in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # Use the metric name stored in `target` (e.g. 'acc') as the key
    meter.update(**{target: metric_value})

metrics = meter.get_results()
pprint.pprint(metrics)

For detailed guidance on grid_search, please check out this wiki page.
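The sketch below is a simplified, hypothetical illustration of what a grid search with `n_trials` and `n_repeats` does conceptually; it is not WRENCH's `grid_search`, whose internals (parallel execution, `direction='auto'`) are not reproduced. It enumerates the Cartesian product of the search space, evaluates at most `n_trials` configurations in random order, averages each over `n_repeats` runs, and keeps the configuration with the best mean score. `simple_grid_search` and `toy_objective` are illustrative names.

```python
import itertools
import random
import statistics
from typing import Callable, Dict, Optional, Sequence

def simple_grid_search(objective: Callable[..., float],
                       search_space: Dict[str, Sequence],
                       n_trials: int,
                       n_repeats: int,
                       seed: int = 0) -> Optional[Dict]:
    """Hypothetical sketch: maximize the mean of `objective` over a grid."""
    names = list(search_space)
    # Cartesian product of all candidate values, one dict per configuration
    grid = [dict(zip(names, values))
            for values in itertools.product(*(search_space[n] for n in names))]
    # Visit configurations in random order and stop after n_trials of them
    random.Random(seed).shuffle(grid)
    best_paras, best_score = None, float("-inf")
    for paras in grid[:n_trials]:
        # Average over repeated runs to reduce the effect of random seeds
        score = statistics.mean(objective(repeat=r, **paras) for r in range(n_repeats))
        if score > best_score:
            best_paras, best_score = paras, score
    return best_paras

# Toy objective standing in for "train a label model, return validation accuracy"
def toy_objective(repeat: int, lr: float, l2: float) -> float:
    return -abs(lr - 1e-3) - l2

space = {"lr": [1e-4, 1e-3, 1e-2], "l2": [0.0, 0.1]}
print(simple_grid_search(toy_objective, space, n_trials=6, n_repeats=3))
```

Note that the Snorkel search space in the example above spans 5 × 5 × 5 = 125 combinations, so with `n_trials = 100` a sampler of this kind would evaluate only part of the grid.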

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier (fall back to CPU if no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
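The `patience` and `evaluation_step` arguments above control early stopping on the validation metric. The loop below is a hypothetical sketch of that mechanism, assuming the validation metric is computed every `evaluation_step` steps and that `patience` counts consecutive evaluations without improvement (an assumption; check MLPModel's documentation for the exact semantics). `train_with_early_stopping` is an illustrative name, not a WRENCH function.

```python
from typing import Callable, Tuple

def train_with_early_stopping(train_step: Callable[[int], None],
                              evaluate: Callable[[], float],
                              n_steps: int,
                              evaluation_step: int,
                              patience: int) -> Tuple[int, float]:
    """Hypothetical sketch; returns (best_step, best_validation_score)."""
    best_step, best_score = 0, float("-inf")
    evals_without_improvement = 0
    for step in range(1, n_steps + 1):
        train_step(step)  # one optimizer update
        if step % evaluation_step != 0:
            continue
        score = evaluate()  # validation metric, higher is better
        if score > best_score:
            best_step, best_score = step, score
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                break  # early stop; a real trainer would typically restore the best weights
    return best_step, best_score

# Toy run: validation scores revealed one per evaluation
scores = iter([0.50, 0.60, 0.55, 0.58, 0.59, 0.70])
print(train_with_early_stopping(lambda step: None, lambda: next(scores),
                                n_steps=1000, evaluation_step=10, patience=3))
```

Under that assumption, `patience = 200` with `evaluation_step = 50` would allow up to 200 × 50 = 10,000 steps without improvement before stopping, well below `n_steps = 100000`.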

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels (fall back to CPU if no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
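The two `fit` calls above differ only in `y_train`. Here is a minimal NumPy sketch of why that matters, assuming the end model is trained with cross-entropy against the given targets (an assumption about MLPModel's internals): soft labels weight each class by the label model's confidence, whereas hard labels keep only the argmax and discard that uncertainty. The helper names are hypothetical, and the argmax here breaks ties toward the lowest class id, which snorkel's `probs_to_preds` may handle differently.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def soft_cross_entropy(logits: np.ndarray, target_probs: np.ndarray) -> float:
    """Mean cross-entropy against per-class target probabilities."""
    return float(-(target_probs * log_softmax(logits)).sum(axis=1).mean())

def to_hard_labels(target_probs: np.ndarray) -> np.ndarray:
    """Argmax conversion; ties go to the lowest class id in this sketch."""
    return target_probs.argmax(axis=1)

def hard_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Standard cross-entropy: the soft version with one-hot targets."""
    one_hot = np.eye(logits.shape[1])[labels]
    return soft_cross_entropy(logits, one_hot)

logits = np.array([[2.0, 0.0], [0.0, 0.0]])      # end-model outputs for two examples
soft_label = np.array([[0.7, 0.3], [0.5, 0.5]])  # label-model probabilities
hard_label = to_hard_labels(soft_label)          # -> array([0, 0])
print(soft_cross_entropy(logits, soft_label), hard_cross_entropy(logits, hard_label))
```

With the hard labels, the second example (on which the label model is undecided) is treated as a confident class-0 example, so the loss no longer reflects that uncertainty.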

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
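To give an intuition for what the `alpha`/`beta` parameters of the synthetic generator control, here is a hypothetical NumPy re-implementation of the idea. It is not WRENCH's `ConditionalIndependentGenerator`: the sampling details (uniform draws within each radius, uniformly random wrong labels) are our assumptions. Given the true label, each labeling function fires with its own propensity and, when it fires, is correct with its own accuracy, independently of the other labeling functions.

```python
import numpy as np

ABSTAIN = -1

def generate_ci_lfs(n_examples: int, n_class: int, n_lfs: int,
                    alpha: float, beta: float,
                    alpha_radius: float, beta_radius: float, seed: int = 0):
    """Hypothetical generator of conditionally independent labeling functions.

    Returns (L, y, accuracies, propensities) where L has shape (n_examples, n_lfs).
    """
    rng = np.random.default_rng(seed)
    y = rng.integers(0, n_class, size=n_examples)  # latent true labels
    # Per-LF accuracy and propensity drawn uniformly around their means (our assumption)
    acc = np.clip(rng.uniform(alpha - alpha_radius, alpha + alpha_radius, size=n_lfs), 0.0, 1.0)
    prop = np.clip(rng.uniform(beta - beta_radius, beta + beta_radius, size=n_lfs), 0.0, 1.0)
    L = np.full((n_examples, n_lfs), ABSTAIN, dtype=int)
    for j in range(n_lfs):
        # Given y, each LF's randomness is independent of the other LFs
        fires = rng.random(n_examples) < prop[j]
        correct = rng.random(n_examples) < acc[j]
        # Uniformly random wrong class: shift y by 1..n_class-1 modulo n_class
        wrong = (y + rng.integers(1, n_class, size=n_examples)) % n_class
        L[:, j] = np.where(fires, np.where(correct, y, wrong), ABSTAIN)
    return L, y, acc, prop

L, y, acc, prop = generate_ci_lfs(10000, n_class=2, n_lfs=10, alpha=0.75, beta=0.1,
                                  alpha_radius=0.2, beta_radius=0.1)
print(L.shape, (L != ABSTAIN).mean())
```

Raising `beta` makes the LFs fire more often (more coverage), while raising `alpha` makes their votes more reliable; the radii control how heterogeneous the LFs are.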

🔧 Contact

Contact person: Jieyu Zhang, [email protected]

Don't hesitate to send us an e-mail if you have any questions.

We're also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!

Comments
  • ModuleNotFoundError: No module named 'tokenizations'

    Hi, I ran into some problems when trying to install the library. I used pip install ws-benchmark==1.1.2rc0 as suggested in the documentation; the installation succeeded, but when I ran the code I got the error ModuleNotFoundError: No module named 'tokenizations'. I then cloned the repository and created the environment with conda env create -f environment.yml, but the installation failed with the following error: FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have any idea what might cause the problem and how I can fix it?

    opened by Gnaiqing 12
  • Is there a limitation of using dataset for different algs?

    First, thank you for building this awesome benchmark. When I try the examples with different datasets (e.g., Astra with the youtube dataset), I get errors like this:

        loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
    KeyError: 'labels'
    

    Can this be fixed?

    opened by mrbeann 8
  • Python Package Installation Fails

    Installing ws-benchmark python package fails due to dependency conflict (see stack trace below).

    Tested on system:

    • OS: Ubuntu
    • Python: 3.8.13
    • Clean virtual environment

    Command to replicate:

    • pip install ws-benchmark

    Stack Trace:

    ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        ws-benchmark 1.1.1 depends on networkx==2.7
        snorkel 0.9.7 depends on networkx<2.4 and >=2.2
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    
    opened by bradleyfowler123 4
  • Using Multiple GPUs

    Hi,

    Is it possible to use multiple GPUs for the experiments, or is this planned for a future release? It would be a nice feature if it is not supported yet.

    Best regards.

    opened by tolgayan 4
  • Running scripts

    Hi, I am trying to run some models on the IMDB dataset.

    MLP:

    import logging
    import torch
    import numpy as np
    from wrench.dataset import load_dataset
    from wrench.labelmodel import Snorkel
    from wrench.logging import LoggingHandler
    from wrench.search import grid_search
    from wrench.endmodel import EndClassifierModel
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
        #### Search Space
        search_space = {
            'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
            'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
        }
    
        #### Initialize the model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam'
        )
    
        #### Search best hyper-parameters using validation set in parallel
        n_trials = 20
        n_repeats = 1
        searched_paras = grid_search(
            model,
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            direction='auto',
            search_space=search_space,
            n_repeats=n_repeats,
            n_trials=n_trials,
            parallel=True,
            device=device,
        )
    
    
        #### Run end model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam',
            **searched_paras
        )
        model.fit(
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            device=device
        )
    
        logger.info(model.predict(test_data).tolist())
    
        acc = model.test(test_data, 'acc')
        logger.info(f'end model (MLP) test acc: {acc}')
    
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
    [I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
    2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:24:36 - label model test acc: 0.716
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
    [I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
    [I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
    [I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
    [I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
    [I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
    [I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
    [I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
    [I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
    [I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
    [I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
    [I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
    [I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
    [I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
    [I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
    [I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
    [I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
    [I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
    [I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
    [I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
    [I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
    [TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
    2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
    2021-10-23 22:33:43 - 
    ==========[hyper parameters]==========
    {
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.1,
            "weight_decay": 0.1
        }
    }
    ==========[backbone config]==========
    {
        "name": "MLP",
        "paras": {
            "hidden_size": 100,
            "dropout": 0.0
        }
    }
    
    2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
    2021-10-23 22:34:09 - [0, 0, 0, ..., 0, 0, 0]
    2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
    

    COSINE:

    import logging
    import torch
    from wrench.dataset import load_dataset
    from wrench.logging import LoggingHandler
    from wrench.labelmodel import Snorkel
    from wrench.endmodel import Cosine
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
    
        # COSINE
        model = Cosine(
            teacher_update=100,
            margin=1.0,
            thresh=0.6,
            lr=1e-5,
            mu=1.0,
            lamda=0.05,
            backbone='BERT',
            backbone_model_name=bert_model_name,
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
        )
    
        model.fit(dataset_train=train_data,
                  dataset_valid=valid_data,
                  y_train=aggregated_hard_labels,
                  evaluation_step=10,
                  metric='acc',
                  patience=50,
                  device=device)
    
        acc = model.test(test_data, 'acc')
    
        logger.info(model.predict(test_data))
    
        logger.info(f'end model (COSINE) test acc: {acc}')
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    [TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
    [TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
    2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:30:14 - label model test acc: 0.716
    2021-10-23 22:30:17 - 
    ==========[hyper parameters]==========
    {
        "teacher_update": 100,
        "margin": 1.0,
        "mu": 1.0,
        "thresh": 0.6,
        "lamda": 0.05,
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.001,
            "weight_decay": 0.0
        }
    }
    ==========[backbone config]==========
    {
        "name": "BERT",
        "paras": {
            "model_name": "bert-base-cased",
            "max_tokens": 512,
            "fine_tune_layers": -1
        }
    }
    ==========[label model_config config]==========
    {
        "name": "MajorityVoting",
        "paras": {}
    }
    
    2021-10-23 22:51:52 - [INFO] early stop @ step 510!
    2021-10-23 22:55:20 - early stop because all the data are filtered!
    2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
    2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5
    

    As can be seen, the label model reaches a test accuracy of 0.716, but the end models only reach 0.5004 (MLP) and 0.5 (COSINE).
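    Both end-model accuracies are at chance level. If the IMDB test split is roughly balanced between its two classes (an assumption here), a model that predicts the same class for every example scores about 0.5, which is consistent with the `[1 1 1 ... 1 1 1]` printed for COSINE. A quick, library-independent check of how concentrated the predictions are is sketched below; `prediction_collapse_report` is a hypothetical helper, not part of wrench.

```python
from collections import Counter


def prediction_collapse_report(preds, labels):
    """Summarize how concentrated a model's predictions are.

    `preds` and `labels` are plain sequences of integer class ids of equal length.
    Returns the most frequent predicted class, the fraction of examples that
    received it, the model's accuracy, and the accuracy a constant predictor
    of that same class would get.
    """
    if len(preds) != len(labels) or len(preds) == 0:
        raise ValueError("preds and labels must be non-empty and of equal length")
    n = len(preds)
    top_class, top_count = Counter(preds).most_common(1)[0]
    accuracy = sum(p == y for p, y in zip(preds, labels)) / n
    # Accuracy of always predicting `top_class`, i.e. that class's share of the labels.
    constant_accuracy = sum(y == top_class for y in labels) / n
    return {
        "top_class": top_class,
        "top_fraction": top_count / n,
        "accuracy": accuracy,
        "constant_accuracy": constant_accuracy,
    }


# Toy balanced binary example in which the model predicts class 1 everywhere.
labels = [0, 1] * 5
preds = [1] * 10
print(prediction_collapse_report(preds, labels))
# {'top_class': 1, 'top_fraction': 1.0, 'accuracy': 0.5, 'constant_accuracy': 0.5}
```

    When `accuracy` is close to `constant_accuracy` and `top_fraction` is near 1.0, the end model has collapsed to one class. One detail that may be worth comparing: the optimizer config printed in the log reports `"lr": 0.001`, while the script passes `lr=1e-5` to `Cosine`.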

    Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is an issue with the hyperparameters?
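    On the hyperparameter question, note the log line "early stop because all the data are filtered!". In the COSINE method, the model's soft predictions are sharpened and only samples whose sharpened confidence exceeds a threshold are kept for self-training. The sketch below assumes that wrench's `thresh=0.6` plays the role of that threshold; this has not been checked against the wrench source, and the helper names are hypothetical. It illustrates the selection rule in plain Python and is not wrench's implementation.

```python
def sharpen(probs):
    """Sharpen soft predictions in the style described for COSINE.

    Each probability is squared and divided by its class's total mass over the
    batch, then every row is renormalized to sum to 1. `probs` is a list of rows
    of class probabilities (each row sums to 1).
    """
    n_class = len(probs[0])
    # Total predicted mass of each class across the batch.
    class_mass = [sum(row[j] for row in probs) for j in range(n_class)]
    sharpened = []
    for row in probs:
        scores = [row[j] ** 2 / class_mass[j] if class_mass[j] > 0 else 0.0
                  for j in range(n_class)]
        total = sum(scores)
        sharpened.append([s / total for s in scores])
    return sharpened


def select_confident(probs, thresh):
    """Indices of samples whose sharpened max probability is above `thresh`."""
    return [i for i, row in enumerate(sharpen(probs)) if max(row) > thresh]


# Near-uniform binary predictions: nothing survives a 0.6 threshold.
print(select_confident([[0.5, 0.5]] * 4, 0.6))  # []
# Mixed confidence: the two confident rows are kept, the 0.55/0.45 row is not.
print(select_confident([[0.9, 0.1], [0.2, 0.8], [0.55, 0.45]], 0.6))  # [0, 1]
```

    In the binary case the sharpened maximum is always at least 0.5, so near-uniform predictions can leave nothing above a 0.6 threshold. How confident the model is at the end of the pretrain stage, and whether a lower `thresh` changes the outcome, are two things worth checking.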

    I would greatly appreciate any advice. I would also be very glad if you could include an example script for running the COSINE model.

    Thanks for the benchmark, I really appreciate it!

    opened by viheheb757 4
  • Reproducing Table 11 for classification

    Thanks for this package @JieyuZ2 -- do you happen to have an orchestration script for reproducing Table 11 (and therefore Table 3) in the Wrench paper?
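    No such script is quoted here. As a generic sketch only: a results table of this kind can be produced by looping over datasets and methods, repeating each configuration over several seeds, and reporting mean ± standard deviation. `run_one` is a hypothetical placeholder for code that trains and evaluates one configuration (for example with wrench); the dataset and method names in the usage lines are illustrative and the scores are fake.

```python
import statistics
from itertools import product


def run_grid(run_one, datasets, methods, seeds):
    """Run every (dataset, method) pair over several seeds.

    `run_one(dataset, method, seed)` is a hypothetical user-supplied callable
    that trains and evaluates a single configuration and returns a float score.
    Returns {(dataset, method): (mean, std)}; std is 0.0 for a single seed.
    """
    results = {}
    for dataset, method in product(datasets, methods):
        scores = [run_one(dataset, method, seed) for seed in seeds]
        std = statistics.stdev(scores) if len(scores) > 1 else 0.0
        results[(dataset, method)] = (statistics.mean(scores), std)
    return results


def format_table(results, datasets, methods):
    """Render results as a tab-separated table with one row per method."""
    lines = ["method\t" + "\t".join(datasets)]
    for method in methods:
        cells = [f"{results[(d, method)][0]:.3f}±{results[(d, method)][1]:.3f}"
                 for d in datasets]
        lines.append(method + "\t" + "\t".join(cells))
    return "\n".join(lines)


# Illustrative stand-in for a real training/evaluation run (fake scores).
def fake_run(dataset, method, seed):
    return 0.5 + 0.01 * seed


datasets = ["youtube", "sms"]
methods = ["MajorityVoting", "Snorkel"]
res = run_grid(fake_run, datasets, methods, seeds=[0, 1, 2])
print(format_table(res, datasets, methods))
```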

    opened by pmangg 3
  • No module named 'wrench.classification.self_training'

    Hi, I am trying to run run_denoise.py but I am getting the following error:

    Traceback (most recent call last):
      File "run_denoise.py", line 5, in <module>
        from wrench.classification import Denoise
      File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
        from .self_training import LDSelfTrain, DDSelfTrain
    ModuleNotFoundError: No module named 'wrench.classification.self_training'
    

    Could you please add LDSelfTrain and DDSelfTrain classes?

    opened by andreaspung 3
  • Questions on the use of ground-truth labels for validation

    Thanks for putting up the benchmark! This is really great work! It seems that both the label model and the end model use the ground-truth labels for validation. For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights: https://github.com/JieyuZ2/wrench/blob/544119e781d010797cf153307aa1090361c99522/wrench/basemodel.py#L286

    I have a few questions regarding this:

    (1) A valid baseline for the label models would be a classifier trained on the validation set, with the LFs' weak labels as features and the ground-truth labels as the target. Given that the validation set for most datasets is not small, I suspect such a model could be a fairly strong baseline compared to the unsupervised label models.

    (2) Just as we combine the weak labels on the training set to get aggregated labels, we could also get aggregated labels for the validation set and use them, instead of the ground-truth labels, to validate the end model. Wouldn't this be a more realistic setting, especially considering that the point of weak supervision is to replace human labeling with programmatic labeling?
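    Proposal (2) can be sketched with plain majority voting: aggregate the LF votes on the validation set the same way as on the training set, then score the end model against those aggregated labels instead of the ground truth. The sketch assumes each example's weak labels are a list of LF votes with -1 marking an abstain; majority voting stands in for whatever label model is actually used, and the helper names are hypothetical.

```python
from collections import Counter

ABSTAIN = -1  # assumed abstain marker in the weak-label matrix


def majority_vote(weak_labels, abstain=ABSTAIN):
    """Aggregate each example's LF votes by majority vote.

    Examples where every LF abstains, or where the top two classes tie,
    are assigned `abstain`.
    """
    aggregated = []
    for votes in weak_labels:
        top = Counter(v for v in votes if v != abstain).most_common(2)
        if not top or (len(top) == 2 and top[0][1] == top[1][1]):
            aggregated.append(abstain)
        else:
            aggregated.append(top[0][0])
    return aggregated


def agreement(pred, target, abstain=ABSTAIN):
    """Fraction of examples with a non-abstain target on which pred matches it."""
    pairs = [(p, t) for p, t in zip(pred, target) if t != abstain]
    if not pairs:
        return float("nan")
    return sum(p == t for p, t in pairs) / len(pairs)


valid_weak = [[1, 1, -1], [0, -1, 0], [-1, -1, -1], [1, 0, -1]]  # LF votes
valid_gold = [1, 0, 1, 0]                                         # ground truth
valid_agg = majority_vote(valid_weak)                             # [1, 0, -1, -1]
end_model_pred = [1, 0, 1, 1]
print(agreement(end_model_pred, valid_agg))   # 1.0  (no gold labels needed)
print(agreement(end_model_pred, valid_gold))  # 0.75 (uses gold labels)
```

    The toy numbers also show the trade-off behind the question: the proxy score only covers examples the LFs label decisively, and it can disagree with the gold-label score.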

    I appreciate any explanations. Thanks!

    opened by wurenzhi 2
  • Clarifying dataset download links

    Great work on the benchmark!

    Under the "Available Datasets" section on the main README, you provide 2 links for downloading the WRENCH datasets:

    One point of confusion is that the expanded datasets found via the Google Drive link differ from those in the direct-download zip file. For example, classification/youtube/train.json on Google Drive has 1686 instances, while the zip file contains 1586 for the same file, matching the statistics reported in the README. Could you make it unambiguous in the documentation which download is the correct one?
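    Which copy is in use can be checked by counting the examples in each split file and comparing the counts with the README table (1586 / 120 / 250 for Youtube). The sketch assumes each split file is a single JSON object keyed by example id (see the dataset-format wiki page for the authoritative layout); the directory path in the usage line is a hypothetical example.

```python
import json
from pathlib import Path


def count_examples(json_path):
    """Return the number of examples stored in one split file.

    Assumes the file holds a JSON object keyed by example id; a JSON list is
    also accepted, since len() works for both.
    """
    with open(json_path, encoding="utf-8") as f:
        return len(json.load(f))


def split_sizes(dataset_dir, splits=("train", "valid", "test")):
    """Map each split name to its example count, skipping missing files."""
    root = Path(dataset_dir)
    sizes = {}
    for split in splits:
        path = root / f"{split}.json"
        if path.exists():
            sizes[split] = count_examples(path)
    return sizes


# Hypothetical path to an extracted copy of the datasets; adjust as needed.
print(split_sizes("datasets/classification/youtube"))
```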

    opened by jason-fries 2
  • Fix retained probabilities

    This pull request fixes a bug that led to the wrong probabilities being stored along with the predictions of each labeling function.

    Previously, all probabilities (a 2D tensor of size batch × classes) were saved alongside the class predictions. However, what should be saved is the probability associated with each of the model's predictions.
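    The difference can be shown on a toy batch: from a batch-by-classes probability matrix, the value to keep for each example is the entry in the column of its predicted class. This is a plain-Python illustration, not the code changed by the pull request; in PyTorch the same selection is commonly written as `probs.max(dim=1)` (which returns values and indices) or with `torch.gather`.

```python
def predictions_with_confidence(probs):
    """For each row of class probabilities, return (predicted_class, its_probability).

    Keeps one probability per prediction instead of the whole row.
    """
    result = []
    for row in probs:
        pred = max(range(len(row)), key=row.__getitem__)  # argmax; first index wins ties
        result.append((pred, row[pred]))
    return result


batch = [[0.1, 0.7, 0.2],
         [0.6, 0.3, 0.1]]
print(predictions_with_confidence(batch))  # [(1, 0.7), (0, 0.6)]
```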

    opened by benbo 2
  • New Release

    Hi! Love the repo, super useful so far and really easy interface to use. Thanks for putting it together!

    I was wondering if there were plans to cut another release any time soon? We use the v1.0 tag for making sure the version is consistent across multiple builds. Noticed a few bug fixes and QOL improvements since the last release, and those would be nice to have marked at a new tag.

    opened by rsmith49 2
  • Numba 0.43 doesn't work with newer Python versions

    The numba package 0.43, specified here, doesn't work with Python 3.9. Upgrading the package to the latest version (0.54) resolves the issue. Traceback:

    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
      warnings.warn(
    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
      warnings.warn(
    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
        from .dawid_skene import DawidSkene
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
        from numba import njit, prange
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
        from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
        from .targets import registry
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
        from . import cpu
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
        from numba import _dynfunc, config
    ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
    
    opened by susuzheng 0
  • COSINE for token classification?

    Hi,

    I would like to know whether the code for the COSINE weak-supervision technique is already capable of token classification. If not, what changes would I need to make to build a weakly supervised training pipeline using weakly labeled and unlabeled datasets?

    opened by KrishnanJothi 0
  • Balance in Dawid Skene is obsolete.

    https://github.com/JieyuZ2/wrench/blob/6d8397956533fc6c2fe50e93fcfe0a2303bdd05f/wrench/labelmodel/dawid_skene.py#L55

    I noticed that this balance variable is not used anywhere in this file. If that is intentional, I think it should be removed from the input parameters.

    opened by ch-shin 1
  • Balance sum to 1

    https://github.com/JieyuZ2/wrench/blob/ab717ac26a76649c8fdb946a28dffe7e682c80ba/wrench/basemodel.py#L303

    Hi, I found a minor issue: the class prior computed by this function does not sum to 1. I hope you can fix it.
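    A prior estimated from label counts sums to 1 as long as every count is divided by the total number of counted labels. The sketch below is not the code at the linked line; it assumes the prior is estimated from a list of integer labels, ignores labels outside `0..n_class-1` (such as an abstain marker), and adds optional smoothing `alpha`. `class_prior` is a hypothetical name.

```python
def class_prior(labels, n_class, alpha=0.0):
    """Estimate a class prior from integer labels; the result always sums to 1.

    Labels outside range(n_class) (e.g. an abstain marker such as -1) are ignored.
    `alpha` adds optional Laplace-style smoothing to every class.
    """
    counts = [0] * n_class
    for y in labels:
        if 0 <= y < n_class:
            counts[y] += 1
    total = sum(counts) + alpha * n_class
    if total == 0:
        # No usable labels and no smoothing: fall back to a uniform prior.
        return [1.0 / n_class] * n_class
    return [(c + alpha) / total for c in counts]


prior = class_prior([0, 1, 1, 1, -1], n_class=2)
print(prior, sum(prior))  # [0.25, 0.75] 1.0
```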

    opened by Gnaiqing 0
  • about COSINE endmodel

    Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model, but I have trouble reproducing the results in the COSINE paper. Although I used the suggested hyperparameters, I still get only a marginal benefit with wrench, and I'm not sure what is wrong. Could you share the scripts you used when evaluating COSINE? Thanks.

    opened by Gnaiqing 0
  • Recommended parameters to use for each algorithms and datasets.

    I've tried several combinations of algorithms and datasets, but I found it hard to get results similar to those in the paper. I suspect this is due to inappropriate parameter settings, so I think it would be great if this repo could provide some recommended parameters (especially for the newly added algorithms, where it's hard to judge whether they produce the right results).

    opened by mrbeann 0
Releases(v1.1)
  • v1.1(Nov 9, 2021)

    What's new:

    • A bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
    • A new EndClassifierModel that unifies all the classification backbones
    • Two new datasets on image classification
    • Support for torch native AMP for inference in the validation step
    • Support for training on multiple GPUs via torch's DistributedDataParallel and the new parallel_fit function
    • Fixed some bugs
  • v1.0(Sep 7, 2021)

Owner
Jieyu Zhang
CS PhD
OBG-FCN - implementation of 'Object Boundary Guided Semantic Segmentation'

OBG-FCN This repository is to reproduce the implementation of 'Object Boundary Guided Semantic Segmentation' in http://arxiv.org/abs/1603.09742 Object

Jiu XU 3 Mar 11, 2019
Neural Architecture Search Powered by Swarm Intelligence 🐜

Neural Architecture Search Powered by Swarm Intelligence 🐜 DeepSwarm DeepSwarm is an open-source library which uses Ant Colony Optimization to tackle

288 Oct 28, 2022
Implementation of our paper "DMT: Dynamic Mutual Training for Semi-Supervised Learning"

DMT: Dynamic Mutual Training for Semi-Supervised Learning This repository contains the code for our paper DMT: Dynamic Mutual Training for Semi-Superv

Zhengyang Feng 120 Dec 30, 2022
naked is a Python tool which allows you to strip a model and only keep what matters for making predictions.

naked is a Python tool which allows you to strip a model and only keep what matters for making predictions. The result is a pure Python function with no third-party dependencies that you can simply c

Max Halford 24 Dec 20, 2022
Code for our paper "MG-GAN: A Multi-Generator Model Preventing Out-of-Distribution Samples in Pedestrian Trajectory Prediction" published at ICCV 2021.

MG-GAN: A Multi-Generator Model Preventing Out-of-Distribution Samples in Pedestrian Trajectory Prediction This repository contains the code for the p

Sven 30 Jan 05, 2023
Python implementation of "Elliptic Fourier Features of a Closed Contour"

PyEFD An Python/NumPy implementation of a method for approximating a contour with a Fourier series, as described in [1]. Installation pip install pyef

Henrik Blidh 71 Dec 09, 2022
TextBPN Adaptive Boundary Proposal Network for Arbitrary Shape Text Detection

TextBPN Adaptive Boundary Proposal Network for Arbitrary Shape Text Detection; Accepted by ICCV2021. Note: The complete code (including training and t

S.X.Zhang 84 Dec 13, 2022
Monk is a low code Deep Learning tool and a unified wrapper for Computer Vision.

Monk - A computer vision toolkit for everyone Why use Monk Issue: Want to begin learning computer vision Solution: Start with Monk's hands-on study ro

Tessellate Imaging 507 Dec 04, 2022
Several simple examples for popular neural network toolkits calling custom CUDA operators.

Neural Network CUDA Example Several simple examples for neural network toolkits (PyTorch, TensorFlow, etc.) calling custom CUDA operators. We provide

WeiYang 798 Jan 01, 2023
【ACMMM 2021】DSANet: Dynamic Segment Aggregation Network for Video-Level Representation Learning

DSANet: Dynamic Segment Aggregation Network for Video-Level Representation Learning (ACMMM 2021) Overview We release the code of the DSANet (Dynamic S

Wenhao Wu 46 Dec 27, 2022
Revisiting Global Statistics Aggregation for Improving Image Restoration

Revisiting Global Statistics Aggregation for Improving Image Restoration Xiaojie Chu, Liangyu Chen, Chengpeng Chen, Xin Lu Paper: https://arxiv.org/pd

MEGVII Research 128 Dec 24, 2022
Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network

DeepCDR Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network This work has been accepted to ECCB2020 and was also published in the

Qiao Liu 50 Dec 18, 2022
Code for ICE-BeeM paper - NeurIPS 2020

ICE-BeeM: Identifiable Conditional Energy-Based Deep Models Based on Nonlinear ICA This repository contains code to run and reproduce the experiments

Ilyes Khemakhem 65 Dec 22, 2022
ICML 21 - Voice2Series: Reprogramming Acoustic Models for Time Series Classification

Voice2Series-Reprogramming Voice2Series: Reprogramming Acoustic Models for Time Series Classification International Conference on Machine Learning (IC

49 Jan 03, 2023
Towards Representation Learning for Atmospheric Dynamics (AtmoDist)

Towards Representation Learning for Atmospheric Dynamics (AtmoDist) The prediction of future climate scenarios under anthropogenic forcing is critical

Sebastian Hoffmann 4 Dec 15, 2022
Code for Learning Manifold Patch-Based Representations of Man-Made Shapes, in ICLR 2021.

LearningPatches | Webpage | Paper | Video Learning Manifold Patch-Based Representations of Man-Made Shapes Dmitriy Smirnov, Mikhail Bessmeltsev, Justi

Dima Smirnov 22 Nov 14, 2022
(Python, R, C/C++) Isolation Forest and variations such as SCiForest and EIF, with some additions (outlier detection + similarity + NA imputation)

IsoTree Fast and multi-threaded implementation of Extended Isolation Forest, Fair-Cut Forest, SCiForest (a.k.a. Split-Criterion iForest), and regular

141 Dec 29, 2022
dyld_shared_cache processing / Single-Image loading for BinaryNinja

Dyld Shared Cache Parser Author: cynder (kat) Dyld Shared Cache Support for BinaryNinja Without any of the fuss of requiring manually loading several

cynder 76 Dec 28, 2022
The modify PyTorch version of Siam-trackers which are speed-up by TensorRT.

SiamTracker-with-TensorRT The modify PyTorch version of Siam-trackers which are speed-up by TensorRT or ONNX. [Updating...] Examples demonstrating how

9 Dec 13, 2022
Code Impementation for "Mold into a Graph: Efficient Bayesian Optimization over Mixed Spaces"

Code Impementation for "Mold into a Graph: Efficient Bayesian Optimization over Mixed Spaces" This repo contains the implementation of GEBO algorithm.

Jaeyeon Ahn 2 Mar 22, 2022