Rax is a Learning-to-Rank library written in JAX

Related tags

Deep Learningrax
Overview

🦖 Rax: Composable Learning to Rank using JAX

Rax is a Learning-to-Rank library written in JAX. Rax provides off-the-shelf implementations of ranking losses and metrics to be used with JAX. It provides the following functionality:

  • Ranking losses (rax.*_loss): rax.softmax_loss, rax.pairwise_logistic_loss, ...
  • Ranking metrics (rax.*_metric): rax.mrr_metric, rax.ndcg_metric, ...
  • Transformations (rax.*_t12n): rax.approx_t12n, rax.gumbel_t12n, ...

Ranking

A ranking problem is different from traditional classification/regression problems in that its objective is to optimize for the correctness of the relative order of a list of examples (e.g., documents) for a given context (e.g., a query). Rax provides support for ranking problems within the JAX ecosystem. It can be used in, but is not limited to, the following applications:

  • Search: ranking a list of documents with respect to a query.
  • Recommendation: ranking a list of items given a user as context.
  • Question Answering: finding the best answer from a list of candidates.
  • Dialogue System: finding the best response from a list of responses.

Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute various ranking losses and metrics:

import jax.numpy as jnp
import rax

scores = jnp.asarray([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.asarray([1., 0., 0.])      # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)         # computes a ranking metric.
rax.pairwise_hinge_loss(scores, labels) # computes a ranking loss.

All of the Rax losses and metrics are purely functional and compose well with standard JAX transformations. Additionally, Rax provides ranking-specific transformations so you can build new ranking losses. An example is rax.approx_t12n, which can be used to transform any (non-differentiable) ranking metric into a differentiable loss. For example:

loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)            # differentiable approx ndcg loss.
jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.

Examples

See the examples/ directory for complete examples on how to use Rax.

You might also like...
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.

[ICLR 2021] RAPID: A Simple Approach for Exploration in Reinforcement Learning This is the Tensorflow implementation of ICLR 2021 paper Rank the Episo

Rank 1st in the public leaderboard of ScanRefer (2021-03-18)
Rank 1st in the public leaderboard of ScanRefer (2021-03-18)

InstanceRefer InstanceRefer: Cooperative Holistic Understanding for Visual Grounding on Point Clouds through Instance Multi-level Contextual Referring

Code for
Code for "LoRA: Low-Rank Adaptation of Large Language Models"

LoRA: Low-Rank Adaptation of Large Language Models This repo contains the implementation of LoRA in GPT-2 and steps to replicate the results in our re

Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]
Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]

Rank & Sort Loss for Object Detection and Instance Segmentation The official implementation of Rank & Sort Loss. Our implementation is based on mmdete

This is the pytorch implementation for the paper: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation, which is accepted to ICCV2021.

GMPQ: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation This is the pytorch implementation for the paper: Generalizable Mix

 COD-Rank-Localize-and-Segment (CVPR2021)
COD-Rank-Localize-and-Segment (CVPR2021)

COD-Rank-Localize-and-Segment (CVPR2021) Simultaneously Localize, Segment and Rank the Camouflaged Objects Full camouflage fixation training dataset i

Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train format

ttopt Description Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train (TT) format and maximu

ViViT: Curvature access through the generalized Gauss-Newton's low-rank structure

ViViT is a collection of numerical tricks to efficiently access curvature from the generalized Gauss-Newton (GGN) matrix based on its low-rank structure. Provided functionality includes computing

This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing.

Feedback Prize - Evaluating Student Writing This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing. The

Comments
  • Add Kendall's tau

    Add Kendall's tau

    Kendall's tau (https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient) is a ranking metric which we could add to Rax. We can reference the SciPy version of Kendall's tau (https://github.com/scipy/scipy/blob/v1.9.0/scipy/stats/_stats_py.py#L5015-L5222) as inspiration or comparison for the Rax implementation.

    Expected usage:

    scores = jnp.array([0., 1., 2.])
    labels = jnp.array([1., 2., 0.])
    rax.kendalltau_metric(scores, labels)  # should produce Kendall's tau-b.
    
    enhancement 
    opened by rjagerman 0
Releases(v0.2.0)
Owner
Google
Google ❤️ Open Source
Google
Visual Memorability for Robotic Interestingness via Unsupervised Online Learning (ECCV 2020 Oral and TRO)

Visual Interestingness Refer to the project description for more details. This code based on the following paper. Chen Wang, Yuheng Qiu, Wenshan Wang,

Chen Wang 36 Sep 08, 2022
Implementation of "The Power of Scale for Parameter-Efficient Prompt Tuning"

Prompt-Tuning Implementation of "The Power of Scale for Parameter-Efficient Prompt Tuning" Currently, we support the following huggigface models: Bart

Andrew Zeng 36 Dec 19, 2022
StackNet is a computational, scalable and analytical Meta modelling framework

StackNet This repository contains StackNet Meta modelling methodology (and software) which is part of my work as a PhD Student in the computer science

Marios Michailidis 1.3k Dec 15, 2022
🌊 Online machine learning in Python

In a nutshell River is a Python library for online machine learning. It is the result of a merger between creme and scikit-multiflow. River's ambition

OnlineML 4k Jan 02, 2023
Data labels and scripts for fastMRI.org

fastMRI+: Clinical pathology annotations for the fastMRI dataset The fastMRI dataset is a publicly available MRI raw (k-space) dataset. It has been us

Microsoft 51 Dec 22, 2022
A toolkit for developing and comparing reinforcement learning algorithms.

Status: Maintenance (expect bug fixes and minor updates) OpenAI Gym OpenAI Gym is a toolkit for developing and comparing reinforcement learning algori

OpenAI 29.6k Jan 08, 2023
Implementations of CNNs, RNNs, GANs, etc

Tensorflow Programs and Tutorials This repository will contain Tensorflow tutorials on a lot of the most popular deep learning concepts. It'll also co

Adit Deshpande 1k Dec 30, 2022
This is an official pytorch implementation of Fast Fourier Convolution.

Fast Fourier Convolution (FFC) for Image Classification This is the official code of Fast Fourier Convolution for image classification on ImageNet. Ma

pkumi 199 Jan 03, 2023
Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation

Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation

Bae, Gwangbin 95 Jan 04, 2023
This is the official github repository of the Met dataset

The Met dataset This is the official github repository of the Met dataset. The official webpage of the dataset can be found here. What is it? This cod

Nikolaos-Antonios Ypsilantis 35 Dec 17, 2022
An open-source Kazakh named entity recognition dataset (KazNERD), annotation guidelines, and baseline NER models.

Kazakh Named Entity Recognition This repository contains an open-source Kazakh named entity recognition dataset (KazNERD), named entity annotation gui

ISSAI 9 Dec 23, 2022
Learning Versatile Neural Architectures by Propagating Network Codes

Learning Versatile Neural Architectures by Propagating Network Codes Mingyu Ding, Yuqi Huo, Haoyu Lu, Linjie Yang, Zhe Wang, Zhiwu Lu, Jingdong Wang,

Mingyu Ding 36 Dec 06, 2022
Implementation for Paper "Inverting Generative Adversarial Renderer for Face Reconstruction"

StyleGAR TODO: add arxiv link Implementation of Inverting Generative Adversarial Renderer for Face Reconstruction TODO: for test Currently, some model

155 Oct 27, 2022
Train neural network for semantic segmentation (deep lab V3) with pytorch in less then 50 lines of code

Train neural network for semantic segmentation (deep lab V3) with pytorch in 50 lines of code Train net semantic segmentation net using Trans10K datas

17 Dec 19, 2022
Predicting Axillary Lymph Node Metastasis in Early Breast Cancer Using Deep Learning on Primary Tumor Biopsy Slides

Predicting Axillary Lymph Node Metastasis in Early Breast Cancer Using Deep Learning on Primary Tumor Biopsy Slides Project | This repo is the officia

CVSM Group - email: <a href=[email protected]"> 33 Dec 28, 2022
Python Auto-ML Package for Tabular Datasets

Tabular-AutoML AutoML Package for tabular datasets Tabular dataset tuning is now hassle free! Run one liner command and get best tuning and processed

Sagnik Roy 18 Nov 20, 2022
A PaddlePaddle version of Neural Renderer, refer to its PyTorch version

Neural 3D Mesh Renderer in PadddlePaddle A PaddlePaddle version of Neural Renderer, refer to its PyTorch version Install Run: pip install neural-rende

AgentMaker 13 Jul 12, 2022
Robust Self-augmentation for NER with Meta-reweighting

Robust Self-augmentation for NER with Meta-reweighting

Lam chi 17 Nov 22, 2022
This repository stores the code to reproduce the results published in "TiWS-iForest: Isolation Forest in Weakly Supervised and Tiny ML scenarios"

TinyWeaklyIsolationForest This repository stores the code to reproduce the results published in "TiWS-iForest: Isolation Forest in Weakly Supervised a

2 Mar 21, 2022
ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation This repository contains the source code of our paper, ESPNet (acc

Sachin Mehta 515 Dec 13, 2022