Powerful unsupervised domain adaptation method for dense retrieval.

Overview

Generative Pseudo Labeling (GPL)

GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on query generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs only the unlabeled target corpus and can achieve significant improvement over zero-shot models.

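For more information, check out our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval (https://arxiv.org/abs/2112.07577).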

Installation

One can either install GPL via pip

pip install gpl

or via git clone

git clone https://github.com/UKPLab/gpl.git && cd gpl
pip install -e .

Usage

GPL accepts data in the BeIR-format. For example, we can download the FiQA dataset hosted by BeIR:

wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip
unzip fiqa.zip
head -n 2 fiqa/corpus.jsonl  # Inspect the data format. For training, GPL only needs this `corpus.jsonl` file as input.
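To prepare your own corpus in the same layout, the snippet below is a minimal sketch, assuming the standard BeIR `corpus.jsonl` fields `_id`, `title` and `text` (one JSON object per line). The helpers `write_beir_corpus` and `read_beir_corpus` are hypothetical names for illustration, not part of the GPL package.

```python
import json
import os
import tempfile
from typing import Dict, Tuple


def write_beir_corpus(folder: str, docs: Dict[str, Tuple[str, str]]) -> str:
    """Write a BeIR-style corpus.jsonl: one {"_id", "title", "text"} object per line."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, "corpus.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        for doc_id, (title, text) in docs.items():
            f.write(json.dumps({"_id": doc_id, "title": title, "text": text}) + "\n")
    return path


def read_beir_corpus(path: str) -> Dict[str, Dict[str, str]]:
    """Read corpus.jsonl back into {doc_id: {"title": ..., "text": ...}}."""
    corpus = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate trailing blank lines
                row = json.loads(line)
                corpus[row["_id"]] = {"title": row.get("title", ""), "text": row["text"]}
    return corpus


if __name__ == "__main__":
    folder = os.path.join(tempfile.mkdtemp(), "my_dataset")
    path = write_beir_corpus(
        folder, {"d1": ("", "Python is a high-level general-purpose programming language.")}
    )
    print(read_beir_corpus(path))
```

The folder produced this way (here `my_dataset/`) would then play the role of `./$dataset` in the commands that follow.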

Then we can either run GPL training directly with `python -m`:

export dataset="fiqa"
python -m gpl.train \
    --path_to_generated_data "generated/$dataset" \
    --base_ckpt 'distilbert-base-uncased' \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --qgen_prefix "qgen" \
    --do_evaluation \
    # --use_amp   # Use this for efficient training if the machine supports AMP

# Run `python -m gpl.train --help` to list all arguments
# To reproduce the experiments in the paper, set `base_ckpt` to "GPL/msmarco-distilbert-margin-mse" (https://huggingface.co/GPL/msmarco-distilbert-margin-mse)

or import GPL's training method in a Python script:

import gpl

dataset = 'fiqa'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt='distilbert-base-uncased',  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
    batch_size_gpl=32,
    gpl_steps=140000,
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    do_evaluation=True,
    # use_amp=True,  # Enable efficient float16 (AMP) training if the machine supports it
)

How does GPL work?

The workflow of GPL is shown as follows:

  1. GPL first uses a seq2seq model (BeIR/query-gen-msmarco-t5-base-v1 by default) to generate queries_per_passage queries for each passage in the unlabeled corpus. The query-passage pairs are treated as positive examples for training.

    Result files (under path $path_to_generated_data): (1) ${qgen}-qrels/train.tsv, (2) ${qgen}-queries.jsonl and also (3) corpus.jsonl (copied from $evaluation_data/);

  2. Then, it runs negative mining on the target corpus, using the generated queries as input. The mined passages are treated as negative examples for training. Any dense retriever (SBERT or Huggingface/transformers checkpoints; msmarco-distilbert-base-v3 + msmarco-MiniLM-L-6-v3 by default) or BM25 can be passed to the retrievers argument as the negative miner.

    Result file (under path $path_to_generated_data): hard-negatives.jsonl;

  3. Finally, it performs pseudo labeling with a powerful cross-encoder (cross-encoder/ms-marco-MiniLM-L-6-v2 by default) on all query-passage pairs collected so far (both positive and negative examples).

    Result file (under path $path_to_generated_data): gpl-training-data.tsv. It contains (gpl_steps * batch_size_gpl) tuples in total.
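With the example settings used earlier (gpl_steps 140000, batch_size_gpl 32), step 3 yields 140,000 × 32 = 4,480,000 tuples. The core of step 3 can be sketched as follows. This is a simplified illustration, not GPL's code: `overlap_teacher` is a hypothetical word-overlap scorer standing in for the cross-encoder, and labeling every mined negative is an assumption (GPL's actual sampling of negatives per training step may differ).

```python
from typing import Callable, Dict, List, Tuple

# (query_id, positive_id, negative_id, teacher_margin)
Row = Tuple[str, str, str, float]


def pseudo_label(
    queries: Dict[str, str],               # query_id -> generated query text
    positives: Dict[str, str],             # query_id -> id of the passage the query came from
    hard_negatives: Dict[str, List[str]],  # query_id -> passage ids returned by the miner(s)
    corpus: Dict[str, str],                # passage_id -> passage text
    teacher: Callable[[str, str], float],  # stand-in for the cross-encoder score CE(q, p)
) -> List[Row]:
    """Label each (query, positive, negative) triple with CE(q, pos) - CE(q, neg)."""
    rows: List[Row] = []
    for qid, query in queries.items():
        pos_id = positives[qid]
        pos_score = teacher(query, corpus[pos_id])
        for neg_id in hard_negatives.get(qid, []):
            if neg_id == pos_id:
                continue  # the miner may retrieve the positive itself; skip it
            rows.append((qid, pos_id, neg_id, pos_score - teacher(query, corpus[neg_id])))
    return rows


def overlap_teacher(query: str, passage: str) -> float:
    """Toy teacher that counts shared lowercase words. NOT a real cross-encoder."""
    return float(len(set(query.lower().split()) & set(passage.lower().split())))


if __name__ == "__main__":
    corpus = {
        "p1": "python is a programming language",
        "p2": "fiqa covers personal finance",
    }
    rows = pseudo_label(
        {"q1": "what is python"}, {"q1": "p1"}, {"q1": ["p2", "p1"]}, corpus, overlap_teacher
    )
    print(rows)
```

Running it prints `[('q1', 'p1', 'p2', 2.0)]`: the query shares two words with its positive and none with the mined negative, and the miner's copy of the positive is skipped.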

At this point, the training data is ready. See sample-data/generated/fiqa for a quick example of the data format. The final step applies the MarginMSE loss, which teaches the student retriever to mimic the margin scores CE(query, positive) - CE(query, negative) labeled by the teacher model (a cross-encoder, CE).
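For intuition, here is a pure-Python sketch of the MarginMSE objective for one batch, assuming the student scores query-passage pairs with a dot product of embeddings. The function names are illustrative; a real training loop would operate on batched tensors (sentence-transformers, for instance, ships a `MarginMSELoss`), and this is not GPL's own code.

```python
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def margin_mse(
    query_embs: List[List[float]],
    pos_embs: List[List[float]],
    neg_embs: List[List[float]],
    teacher_margins: List[float],
) -> float:
    """Mean over the batch of (student_margin - teacher_margin) ** 2.

    student_margin = dot(q, pos) - dot(q, neg)
    teacher_margin = CE(q, pos) - CE(q, neg), read from the pseudo-labeled data
    """
    if not teacher_margins:
        raise ValueError("empty batch")
    squared_errors = []
    for q, p, n, t in zip(query_embs, pos_embs, neg_embs, teacher_margins):
        student_margin = dot(q, p) - dot(q, n)
        squared_errors.append((student_margin - t) ** 2)
    # This value is minimized during training, so it carries no leading minus sign
    return sum(squared_errors) / len(squared_errors)


if __name__ == "__main__":
    loss = margin_mse(
        query_embs=[[1.0, 0.0], [0.0, 1.0]],
        pos_embs=[[0.8, 0.1], [0.0, 1.0]],
        neg_embs=[[0.2, 0.5], [0.0, 0.0]],
        teacher_margins=[0.6, 0.5],
    )
    print(round(loss, 6))
```

In this toy batch the first student margin (0.8 - 0.2) already matches the teacher, while the second (1.0 vs. 0.5) contributes a squared error of 0.25, so the loss is about 0.125.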

Customized data

One can also place customized data for any intermediate step under $path_to_generated_data, using the same file names as listed above. GPL will then skip the corresponding intermediate steps and use the provided data.
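To see which of your own files would cover which step, the following hypothetical helper (`missing_steps` is not a GPL function) checks for the result files named in the workflow description. It is only a sketch under that assumption; the actual checks inside gpl.train may differ, for example in how changes to earlier outputs affect later steps.

```python
import os
from typing import Dict, List


def missing_steps(path_to_generated_data: str, qgen_prefix: str = "qgen") -> List[str]:
    """Hypothetical helper: list the GPL steps whose result files are not present yet."""
    expected: Dict[str, List[str]] = {
        "query generation": [
            os.path.join(path_to_generated_data, f"{qgen_prefix}-qrels", "train.tsv"),
            os.path.join(path_to_generated_data, f"{qgen_prefix}-queries.jsonl"),
        ],
        "negative mining": [os.path.join(path_to_generated_data, "hard-negatives.jsonl")],
        "pseudo labeling": [os.path.join(path_to_generated_data, "gpl-training-data.tsv")],
    }
    # A step counts as covered only if all of its result files exist
    return [step for step, files in expected.items() if not all(os.path.exists(f) for f in files)]


if __name__ == "__main__":
    print(missing_steps("generated/fiqa"))
```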

Citation

If you use the code for evaluation, feel free to cite our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval:

@article{wang2021gpl,
    title = "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval",
    author = "Kexin Wang and Nandan Thakur and Nils Reimers and Iryna Gurevych", 
    journal= "arXiv preprint arXiv:2112.07577",
    month = "4",
    year = "2021",
    url = "https://arxiv.org/abs/2112.07577",
}

Contact person and main contributor: Kexin Wang, [email protected]

https://www.ukp.tu-darmstadt.de/

https://www.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

Comments
  • Error while running the training script

    [2022-04-14 06:00:25] INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling
    0%| | 0/140000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/ec2-user/SageMaker/gpl/gpl/toolkit/pl.py", line 63, in run
        batch = next(hard_negative_iterator)
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
        data = self._next_data()
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 569, in _next_data
        index = self._next_index()  # may raise StopIteration
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in _next_index
        return next(self._sampler_iter)  # may raise StopIteration
    StopIteration

    opened by kingafy 3
  • Loss function

    Is the minus sign "-" in the MarginMSE loss function in Equation (1) of the GPL paper a typo?

    There should be no minus sign, because the model should minimize MSE(delta_teacher, delta_student), not maximize it. I checked the released GPL code, and its loss function has no minus sign.

    opened by dli1 2
  • GPU speedup

    I reckon this is more of a generic question for TSDAE + GPL (or any transformer used), but can you use the GPU simply by doing something like gpl.to(device)?

    opened by ahadda5 1
  • [KTLO-6] Hints for missing evaluation data

    The previous code did not give enough hints about missing evaluation data.

    • gpl/toolkit/evaluation.py: Added checking for missing evaluation data
    • tests/unit/conftest.py: Separated sbert and sbert_path fixtures
    • tests/unit/test_eval.py: Added test
    opened by kwang2049 0
  • [KTLO-5] batch size larger than data size

    The previous code did not check whether the batch size is larger than the number of data points (or number of generated queries) in PseudoLabeler.run

    • gpl/toolkit/pl.py: Added a check at the beginning of run comparing batch size against data size
    • tests/unit/test_pl.py: Added test
    opened by kwang2049 0
  • [KTLO-4] OOM error in qgen

    The previous code did not detect OOM errors in QGen, which can be caused by a large QPP or batch size.

    • modified: gpl/toolkit/qgen.py: Added try-catch
    • new file: tests/unit/test_qgen.py: Added test

    opened by kwang2049 0
  • [KTLO-3] OOM error in loadable checking

    The current version could not identify an OOM error in loadable_by_sbert_oom: OOM is also a runtime error, and this loadable check treats every runtime error as "not loadable".

    • modified: gpl/toolkit/sbert.py: Raise OOM error (runtime error)
    • modified: setup.py: Added pytest
    • new file: tests/unit/conftest.py: SBERT fixture
    • new file: tests/unit/test_sbert.py: Test OOM error case
    opened by kwang2049 0
  • [KTLO-0] New EES version and black formatting

    • README.md: Hint of installing PyTorch correctly wrt. the CUDA version.
    • gpl/toolkit/beir.py: Black
    • gpl/toolkit/dataset.py: Black
    • gpl/toolkit/evaluation.py: Black
    • gpl/toolkit/log.py: Black
    • gpl/toolkit/loss.py: Black
    • gpl/toolkit/mine.py: Black
    • gpl/toolkit/mnrl.py: Black
    • gpl/toolkit/pl.py: Black
    • gpl/toolkit/qgen.py: Black
    • gpl/toolkit/reformat.py: Black
    • gpl/toolkit/rescale.py: Black
    • gpl/toolkit/resize.py: Black
    • gpl/toolkit/sbert.py: Black
    • gpl/train.py: Black
    • setup.py: Added protobuf, which T5 requires and which apparently is not installed along with transformers; specified ees>=0.0.8 (where the es version is kept the same as that required by beir)
    opened by kwang2049 0
  • Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language")?

    Hi. Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language" in your example)? In your pipeline, the first step is Query Generation: "For a given text from our domain, we first use a T5 model that generates a possible query for the given text. E.g. when your text is “Python is a high-level general-purpose programming language”, the model might generate a query like “What is Python”. You can find various query generators on our doc2query-hub." Does that mean that texts which cannot be converted into queries (e.g. "Investment consulting for legal entities and individuals.") cannot be used for training?

    opened by edgar2597 0
  • GPL for sentence embedding tasks?

    In the provided examples, GPL is used for semantic search tasks: given a query, relevant results should be retrieved. Is it also the recommended approach for getting meaningful embeddings / bi-encoders, or is it better to use TSDAE?

    opened by hanshupe 2
  • Guidance on gpl_steps, new_size and batch_size_gpl

    Hello,

    I am looking for some guidance on below parameters of gpl.train().

    • gpl_steps - Do we need a value as large as 140000 for a corpus of size 1300?
    • new_size
    • batch_size_gpl - Would it help to speed up training if we set this to 64 or 128? How should the values of these parameters be derived from the dataset or corpus.jsonl?
    opened by MyBruso 0
  • TSDAE to GPL... Error on start

    I'm trying to go from my trained TSDAE model and then apply GPL, but I keep getting errors.

    ! export dataset="hs_resume_tsdae_gpl_mini" 
    ! python -m gpl.train \
        --path_to_generated_data "generated/$dataset" \
        --base_ckpt "/Users/cfeld/Desktop/dev/trajectory/finetuning/gpl/outputs/tsdae/MiniLM-L6-H384-uncased-model" \
        --gpl_score_function "dot" \
        --batch_size_gpl 34 \
        --gpl_steps 100 \
        --queries_per_passage 1 \
        --output_dir "output/$dataset" \
        --evaluation_data "./$dataset" \
        --evaluation_output "evaluation/$dataset" \
        --generator "BeIR/query-gen-msmarco-t5-base-v1" \
        --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
        --retriever_score_functions "cos_sim" "cos_sim" \
        --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
        --use_train_qrels
    

    However, I'm getting this error:

    2022-09-12 17:37:44 - Loading faiss.
    2022-09-12 17:37:44 - Successfully loaded faiss.
    /opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py:127: RuntimeWarning: 'gpl.train' found in sys.modules after import of package 'gpl', but prior to execution of 'gpl.train'; this may result in unpredictable behaviour
      warn(RuntimeWarning(msg))
    [2022-09-12 17:37:44] INFO [gpl.train.train:79] Corpus does not exist in generated/. Now clone the one from the evaluation path ./
    [2022-09-12 17:37:44] WARNING [gpl.train.train:106] Found `qgen_prefix` is not None. By setting `use_train_qrels == True`, the `qgen_prefix` will not be used
    [2022-09-12 17:37:44] INFO [gpl.train.train:113] Loading qrels and queries from labeled data under the path of `evaluation_data`
    Traceback (most recent call last):
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 250, in <module>
        train(**vars(args))
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 114, in train
        assert 'qrels' in os.listdir(evaluation_data) and 'queries.jsonl' in os.listdir(evaluation_data)
    AssertionError
    

    Perhaps my folder structure isn't quite right? I've tried all kinds of combos... Folder:

    corpus.jsonl
    evaluation
      - corpus.jsonl
      - hs_resume_tsdae_gpl_mini
        -- corpus.jsonl
    generated
      - corpus.jsonl
      - hs_resume_tsdae_gpl_mini
        -- corpus.jsonl
    hs_resume_tsdae_gpl_mini
      - corpus.jsonl
    output
      - hs_resume_tsdae_gpl_mini

    opened by christophermfeld 1
  • Evaluation data format

    Hi,

    1/ What format should the data passed via the evaluation_data argument have? Could you provide an example of evaluation data and how it should be formatted?

    2/ How does evaluation work on these data? Which tests are run and which labels are used?

    Thanks!

    opened by Matthieu-Tinycoaching 0
  • RuntimeError: CUDA out of memory

    Hi,

    When trying to generate intermediate results with the following command:

    dataset = 'tiny'
    gpl.train(
        path_to_generated_data=f"generated/{dataset}",
        base_ckpt='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',  
        # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
        gpl_score_function="dot",
        # Note that GPL uses MarginMSE loss, which works with dot-product
        batch_size_gpl=32,
        gpl_steps=140000,
        new_size=-1,
        # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set 3 and |corpus| will be set to 250K / 3
        queries_per_passage=-1,
        # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set 3 and |corpus| will be set to 250K / 3
        output_dir=f"output/{dataset}",
        evaluation_data=f"./{dataset}",
        evaluation_output=f"evaluation/{dataset}",
        generator="BeIR/query-gen-msmarco-t5-large-v1",
        retrievers=["msmarco-distilbert-base-tas-b", "msmarco-MiniLM-L6-cos-v5"],
        retriever_score_functions=["dot", "cos_sim"],
        # Note that these two retriever model work with cosine-similarity
        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
        qgen_prefix="qgen",
        # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
        do_evaluation=True,
        use_amp=True   # One can use this flag for enabling the efficient float16 precision
    )
    

    I got the following error:

    2022-08-26 11:55:08 - Loading faiss with AVX2 support.
    2022-08-26 11:55:08 - Could not load library with AVX2 support due to:
    ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
    2022-08-26 11:55:08 - Loading faiss.
    2022-08-26 11:55:08 - Successfully loaded faiss.
    [2022-08-26 11:55:10] INFO [gpl.train.train:79] Corpus does not exist in generated/tiny. Now clone the one from the evaluation path ./tiny
    [2022-08-26 11:55:10] INFO [gpl.train.train:84] Automatically set `new_size` to 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 277639.61it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] WARNING [gpl.toolkit.resize.resize:19] `new_size` should be smaller than the corpus size
    [2022-08-26 11:55:10] INFO [gpl.toolkit.resize.resize:41] Resized the corpus in ./tiny to generated/tiny with new size 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 321974.74it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] INFO [gpl.train.train:99] Automatically set `queries_per_passage` to 59
    [2022-08-26 11:55:10] INFO [gpl.train.train:125] No generated queries found. Now generating it
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 308459.11it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:20] INFO [beir.generation.models.auto_model.__init__:16] Use pytorch device: cuda
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:40] Starting to Generate 59 Questions Per Passage using top-p (nucleus) sampling...
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:41] Params: top_p = 0.95
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:42] Params: top_k = 25
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:43] Params: max_length = 64
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:44] Params: ques_per_passage = 59
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:45] Params: batch size = 32
    pas:   0%|                                                                                                                                                                                          | 0/133 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/matthieu/Tinycoaching/GPL/v.0.1.0/gpl_query_generation.py", line 316, in <module>
        gpl.train(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/train.py", line 127, in train
        qgen(path_to_generated_data, path_to_generated_data, generator_name_or_path=generator, ques_per_passage=queries_per_passage, bsz=batch_size_generation, qgen_prefix=qgen_prefix)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/toolkit/qgen.py", line 23, in qgen
        generator.generate(corpus, output_dir=output_dir, ques_per_passage=ques_per_passage, prefix=prefix, batch_size=bsz)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/generate.py", line 54, in generate
        queries = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/models/auto_model.py", line 28, in generate
        outs = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1326, in generate
        return self.sample(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1944, in sample
        outputs = self(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1639, in forward
        decoder_outputs = self.decoder(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1035, in forward
        layer_outputs = layer_module(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 692, in forward
        cross_attention_outputs = self.layer[1](
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 606, in forward
        attention_output = self.EncDecAttention(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 509, in forward
        scores = torch.matmul(
    RuntimeError: CUDA out of memory. Tried to allocate 584.00 MiB (GPU 0; 23.70 GiB total capacity; 20.69 GiB already allocated; 587.94 MiB free; 20.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    

    My corpus consists of short paragraphs of 3-4 lines, and I used the use_amp option. How could I deal with this?

    opened by Matthieu-Tinycoaching 1
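Two of the KTLO pull requests listed above add early checks: batch size vs. data size (KTLO-5) and telling OOM errors apart from other runtime errors when checking loadability (KTLO-3). The plain-Python sketch below illustrates both ideas; every name in it is hypothetical, not GPL's implementation. `full_batches` assumes a loader that drops the last incomplete batch, which is one plausible way to end up with the bare StopIteration in the first comment; we have not verified GPL's loader settings.

```python
from typing import Callable, Iterator, List, Sequence


def full_batches(data: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield only complete batches, like a loader that drops the last partial batch."""
    for start in range(0, len(data) - batch_size + 1, batch_size):
        yield list(data[start:start + batch_size])


def check_batch_size(num_examples: int, batch_size: int) -> None:
    """Fail early with a readable message instead of a bare StopIteration later."""
    if batch_size > num_examples:
        raise ValueError(
            f"batch size ({batch_size}) is larger than the number of "
            f"generated queries ({num_examples}); lower the batch size"
        )


def is_oom(err: BaseException) -> bool:
    """CUDA OOM errors surface as RuntimeError with 'out of memory' in the message."""
    return isinstance(err, RuntimeError) and "out of memory" in str(err).lower()


def loadable(load: Callable[[], object]) -> bool:
    """Return False for ordinary load failures, but let OOM errors propagate."""
    try:
        load()
    except RuntimeError as err:
        if is_oom(err):
            raise  # running out of memory says nothing about the checkpoint format
        return False
    return True


if __name__ == "__main__":
    queries = ["q1", "q2"]
    try:
        next(full_batches(queries, 32))
    except StopIteration:
        print("StopIteration: 2 queries cannot fill a batch of 32")
    try:
        check_batch_size(len(queries), 32)
    except ValueError as err:
        print(f"ValueError: {err}")
```

Raising a ValueError up front gives the user an actionable message, whereas the StopIteration only surfaces deep inside the data-loading stack.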
Releases (v0.1.4)
  • v0.1.4 (Sep 29, 2022)

  • v0.1.3 (Sep 26, 2022)

    Previously, there was a conflict between easy_elasticsearch and beir on the dependency of elasticsearch:

    • easy_elasticsearch requires elasticsearch==7.12.1 while
    • beir requires elasticsearch==7.9.1

    In the latest version of easy_elasticsearch, the requirements have been changed to solve this issue. Here we update gpl to install this version (easy_elasticsearch==0.0.9). easy_elasticsearch==0.0.9 also fixes an issue where ES could return empty results (because refresh was not called after indexing).

  • v0.1.0 (Apr 19, 2022)

    Updated paper, accepted by NAACL 2022

    The GPL paper has been accepted by NAACL 2022! Major updates:

    • Improved the setting: Down-sampled the corpus if it is too large; calculate the number of generated queries according to the corpus size;
    • Added more analysis about the influence of the number of generated queries: Small corpus needs more queries;
    • Added results on all 18 BeIR datasets: the conclusions remain the same, and we also tried training GPL on top of the powerful TAS-B model and achieved further improvements.

    Automatic hyper-parameter

    Previously, we used the whole corpus and 3 generated queries per passage regardless of corpus size. This resulted in very poor training efficiency for large corpora. In the new version, we set these two hyper-parameters automatically so that the total number of generated queries is 250K.

    In detail, we keep queries_per_passage >= 3 and uniformly down-sample the corpus if 3 × |C| > 250K, where |C| is the corpus size; we then calculate queries_per_passage = 250K/|C| (rounded up). For example, the queries_per_passage values for FiQA (original size = 57.6K) and Robust04 (original size = 528.2K) are 5 and 3, respectively, and the Robust04 corpus is down-sampled to 83.3K.

    Released checkpoints (TAS-B ones)

    We now release the pre-trained GPL models at https://huggingface.co/GPL. They also include the powerful GPL models trained on top of TAS-B.

  • v0.0.9 (Jan 11, 2022)

    Fixed bug of max.-sequence-length mismatch between student and teacher

    Previously, the teacher (i.e. the cross-encoder) received the concatenation of the query and document texts as input, with no max.-sequence-length limit (cf. here and here). However, the student applied max.-sequence-length limits to the query texts and the document texts separately. This caused a mismatch between the information visible to the student and to the teacher.

    In the new release, we fixed this by doing "retokenization": Right before pseudo labeling, we let the tokenizer of the teacher model tokenize the query texts and the document texts also separately and then decode the results (token IDs) back into the texts again. The resulting texts will meet the same max.-sequence-length requirements as the student model does and thus fix this bug.

    Keep full precision of the pseudo labels

    Previously, we saved the pseudo labels directly from PyTorch tensors, which did not preserve full precision. We have now fixed this by calling labels.tolist() right before dumping the data. The impact should be small, since the previous format already kept 6-digit precision.

  • v0.0.8 (Dec 20, 2021)

    Independent evaluation and k_values supported

    One can now run the gpl.toolkit.evaluation directly. Previously, it was only possible as part of the whole gpl.train workflow. Please check this example for more details.

    We have also added the argument k_values to gpl.toolkit.evaluation.evaluate. It specifies the cutoff values K of the "@K" ranking metrics (e.g. nDCG@K).

    Fixed bugs & use load_sbert in mnrl and evaluation

    Now almost all methods that require a separation token have an argument called sep (previously it was fixed to a blank token " "). The two exceptions are mnrl (a loss function in the SBERT repo, also the default training loss for the QGen method) and qgen, since they come from the BeIR repo (we will update the BeIR repo in the future if possible).

  • v0.0.7 (Dec 17, 2021)

    Rewrite SBERT loading

    Previously, GPL loaded starting checkpoints (--base_ckpt) by constructing an SBERT model from scratch. This lost some information from the checkpoint (e.g. pooling and max_seq_length), which then had to be specified carefully by hand.

    Now we have created another method called load_sbert. It uses SentenceTransformer(base_ckpt) to load the checkpoint directly and performs some checks and assertions. Loading from a Huggingface-format checkpoint (e.g. "distilbert-base-uncased") is still possible in many cases, but we recommend loading from an SBERT-format checkpoint where possible, since this makes misusing the starting checkpoint less likely.

    Reformatting examples

    In some cases, Huggingface-format checkpoints cannot be loaded directly by SBERT, e.g. "facebook/dpr-question_encoder-single-nq-base". This is because:

    1. They are not in SBERT-format but in Huggingface-format;
    2. For Huggingface-format checkpoints, SBERT only works when the last layer is a Transformer layer, i.e. the outputs should contain hidden states with shape (batch_size, sequence_length, hidden_dimension).

    To use these checkpoints, one needs to reformat them into SBERT-format. We have provided two examples/templates in the new toolkit source file, gpl/toolkit/reformat.py. Please refer to its readme here.

    Solved logging bug

    Previously, the logging in GPL was overridden by some other loggers, so the formatting could not be displayed as intended. We have now solved this by configuring the root logger. The new format shows many useful details:

    fmt='[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s'
    
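The automatic hyper-parameter rule described in the v0.1.0 release can be checked with a few lines of arithmetic. This is a sketch of the rule as stated there (a total budget of 250K generated queries, at least 3 queries per passage, rounding up), not GPL's own code; `auto_settings` is a hypothetical name. Note that GPL's internal bookkeeping may differ: the log for the 4,252-document corpus in the comments shows it setting `new_size` to 83334 and warning, while the queries_per_passage of 59 matches this rule.

```python
import math
from typing import Tuple

TOTAL_QUERIES = 250_000  # target total number of generated queries
MIN_QPP = 3              # minimum number of queries per passage


def auto_settings(
    corpus_size: int, total: int = TOTAL_QUERIES, min_qpp: int = MIN_QPP
) -> Tuple[int, int]:
    """Return (corpus size after down-sampling, queries_per_passage)."""
    if min_qpp * corpus_size > total:
        # Large corpus: down-sample it so that min_qpp queries per passage meet the budget
        return math.ceil(total / min_qpp), min_qpp
    # Small corpus: keep every passage and spread the budget over them (rounded up)
    return corpus_size, math.ceil(total / corpus_size)


if __name__ == "__main__":
    for name, size in [("FiQA", 57_600), ("Robust04", 528_200), ("tiny", 4_252)]:
        new_size, qpp = auto_settings(size)
        print(f"{name}: |C| = {new_size}, queries_per_passage = {qpp}")
```

For FiQA this gives 5 queries per passage on the full corpus; for Robust04 it down-samples to 83,334 passages with 3 queries each; for the 4,252-document corpus it gives 59, the value reported in that log.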
Owner
Ubiquitous Knowledge Processing Lab
Ubiquitous Knowledge Processing Lab
Predicting Event Memorability from Contextual Visual Semantics

Predicting Event Memorability from Contextual Visual Semantics

0 Oct 06, 2021
Fast Learning of MNL Model From General Partial Rankings with Application to Network Formation Modeling

Fast-Partial-Ranking-MNL This repo provides a PyTorch implementation for the CopulaGNN models as described in the following paper: Fast Learning of MN

Xingjian Zhang 3 Aug 19, 2022
PSML: A Multi-scale Time-series Dataset for Machine Learning in Decarbonized Energy Grids

PSML: A Multi-scale Time-series Dataset for Machine Learning in Decarbonized Energy Grids The electric grid is a key enabling infrastructure for the a

Texas A&M Engineering Research 19 Jan 07, 2023
3D HourGlass Networks for Human Pose Estimation Through Videos

3D-HourGlass-Network 3D CNN Based Hourglass Network for Human Pose Estimation (3D Human Pose) from videos. This was my summer'18 research project. Dis

Naman Jain 51 Jan 02, 2023
Python Notebooks for Electromagnetic Methods

GeoSci Labs This is a repository of code used to power the notebooks and interactive examples for https://em.geosci.xyz and https://gpg.geosci.xyz. Th

Victor Cezar Tocantins 1 Nov 16, 2021
A Text Attention Network for Spatial Deformation Robust Scene Text Image Super-resolution (CVPR2022)

A Text Attention Network for Spatial Deformation Robust Scene Text Image Super-resolution (CVPR2022) https://arxiv.org/abs/2203.09388 Jianqi Ma, Zheto

MA Jianqi, shiki 104 Jan 05, 2023
CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP

CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP Andreas Fürst* 1, Elisabeth Rumetshofer* 1, Viet Tran1, Hubert Ramsauer1, Fei Tang3, Joh

Institute for Machine Learning, Johannes Kepler University Linz 133 Jan 04, 2023
Code repository for paper `Skeleton Merger: an Unsupervised Aligned Keypoint Detector`.

Skeleton Merger Skeleton Merger, an Unsupervised Aligned Keypoint Detector. The paper is available at https://arxiv.org/abs/2103.10814. A map of the r

北海若 48 Nov 14, 2022
A rough implementation of the paper "A Steering Algorithm for Redirected Walking Using Reinforcement Learning"

Somnus `Chen 2 Jun 09, 2022
Half Instance Normalization Network for Image Restoration

HINet Half Instance Normalization Network for Image Restoration, based on https://github.com/megvii-model/HINet. Dependencies NumPy PyTorch, preferabl

Holy Wu 4 Jun 06, 2022
Rewrite ultralytics/yolov5 v6.0 opencv inference code based on numpy, no need to rely on pytorch

Rewrite ultralytics/yolov5 v6.0 opencv inference code based on numpy, no need to rely on pytorch; pre-processing and post-processing using numpy instead of pytorch.

炼丹去了 21 Dec 12, 2022
Code for the paper "Location-aware Single Image Reflection Removal"

Location-aware Single Image Reflection Removal The shown images are provided by the datasets from IBCLN, ERRNet, SIR2 and the Internet images. The cod

72 Dec 08, 2022
Model parallel transformers in Jax and Haiku

Mesh Transformer Jax A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers. See enwik8_example.py fo

Ben Wang 4.8k Jan 01, 2023
Official Pytorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral)"

Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral): Official Project Webpage This repository provides the off

Kakao Enterprise Corp. 68 Dec 17, 2022
A real world application of a Recurrent Neural Network on a binary classification of time series data

What is this This is a real world application of a Recurrent Neural Network on a binary classification of time series data. This project includes data

Josep Maria Salvia Hornos 2 Jan 30, 2022
Tech Resources for Academic Communities

Free tech resources for faculty, students, researchers, life-long learners, and academic community builders for use in tech based courses, workshops, and hackathons.

Microsoft 2.5k Jan 04, 2023
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

24 Dec 31, 2022
Libraries, tools and tasks created and used at DeepMind Robotics.

dm_robotics: Libraries, tools, and tasks created and used for Robotics research at DeepMind. Package overview Package Summary Transformations Rigid bo

DeepMind 273 Jan 06, 2023
BBScan py3 - BBScan py3 With Python

BBScan_py3 This repository is forked from lijiejie/BBScan 1.5. I migrated the fo

baiyunfei 12 Dec 30, 2022
Hummingbird compiles trained ML models into tensor computation for faster inference.

Hummingbird Introduction Hummingbird is a library for compiling trained traditional ML models into tensor computations. Hummingbird allows users to se

Microsoft 3.1k Dec 30, 2022