TabNet for fastai

Overview

TabNet for fastai

This is an adaptation of TabNet (Attention-based network for tabular data) for fastai (>=2.0) library. The original paper https://arxiv.org/pdf/1908.07442.pdf.

Install

pip install fast_tabnet

How to use

model = TabNetModel(emb_szs, n_cont, out_sz, embed_p=0., y_range=None, n_d=8, n_a=8, n_steps=3, gamma=1.5, n_independent=2, n_shared=2, epsilon=1e-15, virtual_batch_size=128, momentum=0.02)

Parameters emb_szs, n_cont, out_sz, embed_p, y_range are the same as for fastai TabularModel.

  • n_d : int Dimension of the prediction layer (usually between 4 and 64)
  • n_a : int Dimension of the attention layer (usually between 4 and 64)
  • n_steps: int Number of sucessive steps in the newtork (usually betwenn 3 and 10)
  • gamma : float Float above 1, scaling factor for attention updates (usually betwenn 1.0 to 2.0)
  • momentum : float Float value between 0 and 1 which will be used for momentum in all batch norm
  • n_independent : int Number of independent GLU layer in each GLU block (default 2)
  • n_shared : int Number of independent GLU layer in each GLU block (default 2)
  • epsilon: float Avoid log(0), this should be kept very low

Example

Below is an example from fastai library, but the model in use is TabNet

from fastai.basics import *
from fastai.tabular.all import *
from fast_tabnet.core import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
df_main,df_test = df.iloc[:-1000].copy(),df.iloc[-1000:].copy()
df_main.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country salary
0 49 Private 101320 Assoc-acdm 12.0 Married-civ-spouse NaN Wife White Female 0 1902 40 United-States >=50k
1 44 Private 236746 Masters 14.0 Divorced Exec-managerial Not-in-family White Male 10520 0 45 United-States >=50k
2 38 Private 96185 HS-grad NaN Divorced NaN Unmarried Black Female 0 0 32 United-States <50k
3 38 Self-emp-inc 112847 Prof-school 15.0 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male 0 0 40 United-States >=50k
4 42 Self-emp-not-inc 82297 7th-8th NaN Married-civ-spouse Other-service Wife Black Female 0 0 50 United-States <50k
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 
             'relationship', 'race', 'native-country', 'sex']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df_main))
to = TabularPandas(df_main, procs, cat_names, cont_names, y_names="salary", 
                   y_block = CategoryBlock(), splits=splits)
dls = to.dataloaders(bs=32)
dls.valid.show_batch()
workclass education marital-status occupation relationship race native-country sex education-num_na age fnlwgt education-num salary
0 Private HS-grad Married-civ-spouse Other-service Wife White United-States Female False 39.000000 196673.000115 9.0 <50k
1 Private HS-grad Married-civ-spouse Craft-repair Husband White United-States Male False 32.000000 198067.999771 9.0 <50k
2 State-gov HS-grad Never-married Adm-clerical Own-child White United-States Female False 18.999999 176633.999931 9.0 <50k
3 Private Some-college Married-civ-spouse Prof-specialty Husband White United-States Male False 67.999999 107626.998490 10.0 <50k
4 Private Masters Never-married Exec-managerial Not-in-family Black United-States Male False 29.000000 214925.000260 14.0 <50k
5 Private HS-grad Married-civ-spouse Priv-house-serv Wife White United-States Female False 22.000000 200109.000126 9.0 <50k
6 Private Some-college Never-married Sales Own-child White United-States Female False 18.000000 60980.998429 10.0 <50k
7 Private Some-college Separated Adm-clerical Not-in-family White United-States Female False 28.000000 334367.998199 10.0 <50k
8 Private 11th Married-civ-spouse Transport-moving Husband White United-States Male False 49.000000 123584.001097 7.0 <50k
9 Private Masters Never-married Prof-specialty Not-in-family White United-States Female False 26.000000 397316.999922 14.0 <50k
to_tst = to.new(df_test)
to_tst.process()
to_tst.all_cols.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
workclass education marital-status occupation relationship race native-country sex education-num_na age fnlwgt education-num salary
31561 5 2 5 9 3 3 40 2 1 -1.505833 -0.559418 -1.202170 0
31562 2 12 5 2 5 3 40 1 1 -1.432653 0.421241 -0.418032 0
31563 5 7 3 4 1 5 40 2 1 -0.115406 0.132868 -1.986307 0
31564 8 12 3 9 1 5 40 2 1 1.494561 0.749805 -0.418032 0
31565 1 12 1 1 5 3 40 2 1 -0.481308 7.529798 -0.418032 0
emb_szs = get_emb_sz(to)

That's the use of the model

model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=8, n_steps=5, mask_type='entmax');
learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=Adam, lr=3e-2, metrics=[accuracy])
learn.lr_find()
SuggestedLRs(lr_min=0.2754228591918945, lr_steep=1.9054607491852948e-06)

png

learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.446274 0.414451 0.817015 00:30
1 0.380002 0.393030 0.818916 00:30
2 0.371149 0.359802 0.832066 00:30
3 0.349027 0.352255 0.835868 00:30
4 0.355339 0.349360 0.836819 00:30

Tabnet interpretability

# feature importance for 2k rows
dl = learn.dls.test_dl(df.iloc[:2000], bs=256)
feature_importances = tabnet_feature_importances(learn.model, dl)
# per sample interpretation
dl = learn.dls.test_dl(df.iloc[:20], bs=20)
res_explain, res_masks = tabnet_explain(learn.model, dl)
plt.xticks(rotation='vertical')
plt.bar(dl.x_names, feature_importances, color='g')
plt.show()

png

def plot_explain(masks, lbls, figsize=(12,12)):
    "Plots masks with `lbls` (`dls.x_names`)"
    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    plt.yticks(np.arange(0, len(masks), 1.0))
    plt.xticks(np.arange(0, len(masks[0]), 1.0))
    ax.set_xticklabels(lbls, rotation=90)
    plt.ylabel('Sample Number')
    plt.xlabel('Variable')
    plt.imshow(masks)
plot_explain(res_explain, dl.x_names)

png

Hyperparameter search with Bayesian Optimization

If your dataset isn't huge you can tune hyperparameters for tabular models with Bayesian Optimization. You can optimize directly your metric using this approach if the metric is sensitive enough (in our example it is not and we use validation loss instead). Also, you should create the second validation set, because you will use the first as a training set for Bayesian Optimization.

You may need to install the optimizer pip install bayesian-optimization

from functools import lru_cache
# The function we'll optimize
@lru_cache(1000)
def get_accuracy(n_d:Int, n_a:Int, n_steps:Int):
    model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=n_d, n_a=n_a, n_steps=n_steps, gamma=1.5)
    learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
    learn.fit_one_cycle(5)
    return float(learn.validate(dl=learn.dls.valid)[1])

This implementation of Bayesian Optimization doesn't work naturally with descreet values. That's why we use wrapper with lru_cache.

def fit_accuracy(pow_n_d, pow_n_a, pow_n_steps):
    n_d, n_a, n_steps = map(lambda x: 2**int(x), (pow_n_d, pow_n_a, pow_n_steps))
    return get_accuracy(n_d, n_a, n_steps)
from bayes_opt import BayesianOptimization

# Bounded region of parameter space
pbounds = {'pow_n_d': (0, 8), 'pow_n_a': (0, 8), 'pow_n_steps': (0, 4)}

optimizer = BayesianOptimization(
    f=fit_accuracy,
    pbounds=pbounds,
)
optimizer.maximize(
    init_points=15,
    n_iter=100,
)
|   iter    |  target   |  pow_n_a  |  pow_n_d  | pow_n_... |
-------------------------------------------------------------
epoch train_loss valid_loss accuracy time
0 0.404888 0.432834 0.793885 00:10
1 0.367979 0.384840 0.818600 00:09
2 0.366444 0.372005 0.819708 00:09
3 0.362771 0.366949 0.823511 00:10
4 0.353682 0.367132 0.823511 00:10
| �[0m 1       �[0m | �[0m 0.8235  �[0m | �[0m 0.9408  �[0m | �[0m 1.898   �[0m | �[0m 1.652   �[0m |
epoch train_loss valid_loss accuracy time
0 0.393301 0.449742 0.810836 00:08
1 0.379140 0.413773 0.815589 00:07
2 0.355790 0.388907 0.822560 00:07
3 0.349984 0.362671 0.828739 00:07
4 0.348000 0.360150 0.827313 00:07
| �[95m 2       �[0m | �[95m 0.8273  �[0m | �[95m 4.262   �[0m | �[95m 5.604   �[0m | �[95m 0.2437  �[0m |
epoch train_loss valid_loss accuracy time
0 0.451572 0.434189 0.781210 00:12
1 0.423763 0.413420 0.805450 00:12
2 0.398922 0.408688 0.814164 00:12
3 0.390981 0.392398 0.808935 00:12
4 0.376418 0.382250 0.817174 00:12
| �[0m 3       �[0m | �[0m 0.8172  �[0m | �[0m 7.233   �[0m | �[0m 6.471   �[0m | �[0m 2.508   �[0m |
epoch train_loss valid_loss accuracy time
0 0.403187 0.413986 0.798162 00:07
1 0.398544 0.390102 0.820184 00:07
2 0.390569 0.389703 0.825253 00:07
3 0.375426 0.385706 0.826996 00:07
4 0.370446 0.383366 0.831115 00:06
| �[95m 4       �[0m | �[95m 0.8311  �[0m | �[95m 5.935   �[0m | �[95m 1.241   �[0m | �[95m 0.3809  �[0m |
epoch train_loss valid_loss accuracy time
0 0.464145 0.458641 0.751267 00:18
1 0.424691 0.436968 0.788023 00:18
2 0.431576 0.436581 0.775824 00:18
3 0.432143 0.437062 0.759506 00:18
4 0.429915 0.438332 0.758555 00:18
| �[0m 5       �[0m | �[0m 0.7586  �[0m | �[0m 2.554   �[0m | �[0m 0.4992  �[0m | �[0m 3.111   �[0m |
epoch train_loss valid_loss accuracy time
0 0.470359 0.475826 0.748891 00:12
1 0.411564 0.409433 0.797053 00:12
2 0.392718 0.397363 0.809727 00:12
3 0.387564 0.380033 0.814322 00:12
4 0.374153 0.378258 0.818916 00:12
| �[0m 6       �[0m | �[0m 0.8189  �[0m | �[0m 4.592   �[0m | �[0m 2.138   �[0m | �[0m 2.824   �[0m |
epoch train_loss valid_loss accuracy time
0 0.547042 0.588752 0.754119 00:18
1 0.491731 0.469795 0.771863 00:18
2 0.454340 0.433961 0.775190 00:18
3 0.424386 0.432385 0.782953 00:18
4 0.397645 0.406420 0.805767 00:19
| �[0m 7       �[0m | �[0m 0.8058  �[0m | �[0m 6.186   �[0m | �[0m 7.016   �[0m | �[0m 3.316   �[0m |
epoch train_loss valid_loss accuracy time
0 0.485245 0.487635 0.751109 00:18
1 0.450832 0.446423 0.750317 00:18
2 0.448203 0.449419 0.755228 00:18
3 0.430258 0.443562 0.744297 00:18
4 0.429821 0.437173 0.761565 00:18
| �[0m 8       �[0m | �[0m 0.7616  �[0m | �[0m 2.018   �[0m | �[0m 1.316   �[0m | �[0m 3.675   �[0m |
epoch train_loss valid_loss accuracy time
0 0.458425 0.455733 0.751584 00:12
1 0.439781 0.467807 0.751109 00:12
2 0.420331 0.432216 0.775190 00:12
3 0.421012 0.421412 0.782319 00:12
4 0.401828 0.413434 0.801014 00:12
| �[0m 9       �[0m | �[0m 0.801   �[0m | �[0m 2.051   �[0m | �[0m 1.958   �[0m | �[0m 2.332   �[0m |
epoch train_loss valid_loss accuracy time
0 0.546997 0.506728 0.761407 00:18
1 0.489712 0.439324 0.799588 00:18
2 0.448558 0.448419 0.786122 00:18
3 0.436869 0.435375 0.801648 00:18
4 0.417128 0.421093 0.798321 00:18
| �[0m 10      �[0m | �[0m 0.7983  �[0m | �[0m 5.203   �[0m | �[0m 7.719   �[0m | �[0m 3.407   �[0m |
epoch train_loss valid_loss accuracy time
0 0.380781 0.463409 0.786439 00:07
1 0.359212 0.461147 0.798321 00:07
2 0.351414 0.368950 0.822719 00:07
3 0.347257 0.367056 0.829373 00:07
4 0.337212 0.362375 0.830799 00:07
| �[0m 11      �[0m | �[0m 0.8308  �[0m | �[0m 6.048   �[0m | �[0m 4.376   �[0m | �[0m 0.08141 �[0m |
epoch train_loss valid_loss accuracy time
0 0.430772 0.430897 0.767744 00:12
1 0.402611 0.432137 0.764259 00:12
2 0.407579 0.409651 0.812104 00:12
3 0.374988 0.391822 0.816698 00:12
4 0.378011 0.389278 0.816065 00:12
| �[0m 12      �[0m | �[0m 0.8161  �[0m | �[0m 7.083   �[0m | �[0m 1.385   �[0m | �[0m 2.806   �[0m |
epoch train_loss valid_loss accuracy time
0 0.402018 0.412051 0.812262 00:09
1 0.372804 0.464937 0.811629 00:09
2 0.368274 0.384675 0.820184 00:09
3 0.364502 0.371920 0.820659 00:09
4 0.348998 0.369445 0.823828 00:09
| �[0m 13      �[0m | �[0m 0.8238  �[0m | �[0m 4.812   �[0m | �[0m 3.785   �[0m | �[0m 1.396   �[0m |
| �[0m 14      �[0m | �[0m 0.8172  �[0m | �[0m 7.672   �[0m | �[0m 6.719   �[0m | �[0m 2.72    �[0m |
epoch train_loss valid_loss accuracy time
0 0.476033 0.442598 0.803549 00:12
1 0.405236 0.414015 0.788973 00:11
2 0.406291 0.451269 0.789449 00:11
3 0.391013 0.393243 0.816065 00:12
4 0.374160 0.377635 0.821451 00:12
| �[0m 15      �[0m | �[0m 0.8215  �[0m | �[0m 6.464   �[0m | �[0m 7.954   �[0m | �[0m 2.647   �[0m |
epoch train_loss valid_loss accuracy time
0 0.390142 0.390678 0.810995 00:06
1 0.381717 0.382202 0.813055 00:06
2 0.368564 0.378705 0.823828 00:06
3 0.358858 0.368329 0.823511 00:07
4 0.353392 0.363913 0.825887 00:06
| �[0m 16      �[0m | �[0m 0.8259  �[0m | �[0m 0.1229  �[0m | �[0m 7.83    �[0m | �[0m 0.3708  �[0m |
epoch train_loss valid_loss accuracy time
0 0.381215 0.422651 0.800697 00:06
1 0.377345 0.380863 0.815906 00:06
2 0.366631 0.370579 0.822877 00:06
3 0.362745 0.366619 0.823352 00:07
4 0.356861 0.364835 0.825887 00:07
| �[0m 17      �[0m | �[0m 0.8259  �[0m | �[0m 0.03098 �[0m | �[0m 3.326   �[0m | �[0m 0.007025�[0m |
epoch train_loss valid_loss accuracy time
0 0.404604 0.443035 0.824461 00:07
1 0.361872 0.388880 0.823669 00:06
2 0.375164 0.369968 0.825095 00:06
3 0.352091 0.363823 0.827947 00:06
4 0.335458 0.362544 0.829373 00:07
| �[0m 18      �[0m | �[0m 0.8294  �[0m | �[0m 7.81    �[0m | �[0m 7.976   �[0m | �[0m 0.0194  �[0m |
epoch train_loss valid_loss accuracy time
0 0.679292 0.677299 0.248891 00:05
1 0.675403 0.678406 0.248891 00:05
2 0.673259 0.673374 0.248891 00:06
3 0.674996 0.673514 0.248891 00:07
4 0.668813 0.673671 0.248891 00:07
| �[0m 19      �[0m | �[0m 0.2489  �[0m | �[0m 0.4499  �[0m | �[0m 0.138   �[0m | �[0m 0.001101�[0m |
epoch train_loss valid_loss accuracy time
0 0.524201 0.528132 0.729880 00:30
1 0.419377 0.403198 0.812104 00:31
2 0.399398 0.418890 0.812421 00:31
3 0.381651 0.391744 0.819075 00:31
4 0.368742 0.377904 0.822085 00:31
| �[0m 20      �[0m | �[0m 0.8221  �[0m | �[0m 0.0     �[0m | �[0m 6.575   �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.681083 0.682397 0.248891 00:05
1 0.672935 0.679371 0.248891 00:06
2 0.675200 0.673466 0.248891 00:06
3 0.674251 0.673356 0.248891 00:06
4 0.668861 0.673186 0.248891 00:06
| �[0m 21      �[0m | �[0m 0.2489  �[0m | �[0m 8.0     �[0m | �[0m 0.0     �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.407246 0.432203 0.801331 00:10
1 0.385086 0.399513 0.811312 00:11
2 0.377365 0.384121 0.816065 00:12
3 0.366855 0.371010 0.823194 00:12
4 0.361931 0.368933 0.825095 00:12
| �[0m 22      �[0m | �[0m 0.8251  �[0m | �[0m 0.0     �[0m | �[0m 4.502   �[0m | �[0m 2.193   �[0m |
epoch train_loss valid_loss accuracy time
0 0.493623 0.476921 0.766001 00:30
1 0.441126 0.443774 0.776774 00:31
2 0.424523 0.437125 0.783904 00:31
3 0.402457 0.408628 0.795944 00:31
4 0.439420 0.431756 0.788973 00:32
| �[0m 23      �[0m | �[0m 0.789   �[0m | �[0m 8.0     �[0m | �[0m 3.702   �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.515615 0.513919 0.751109 00:31
1 0.462674 0.495322 0.751584 00:31
2 0.465430 0.483685 0.751267 00:31
3 0.481308 0.495375 0.755070 00:31
4 0.481324 0.491275 0.754911 00:32
| �[0m 24      �[0m | �[0m 0.7549  �[0m | �[0m 6.009   �[0m | �[0m 0.0     �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.422837 0.403953 0.819392 00:06
1 0.380753 0.367345 0.826838 00:06
2 0.353045 0.365174 0.830006 00:07
3 0.348628 0.364282 0.826362 00:07
4 0.343561 0.361509 0.829214 00:07
| �[0m 25      �[0m | �[0m 0.8292  �[0m | �[0m 3.522   �[0m | �[0m 8.0     �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.807766 1.307279 0.481622 00:31
1 0.513308 0.499470 0.783587 00:32
2 0.445906 0.492620 0.798004 00:31
3 0.385094 0.399986 0.807509 00:32
4 0.387228 0.384739 0.817015 00:31
| �[0m 26      �[0m | �[0m 0.817   �[0m | �[0m 0.0     �[0m | �[0m 8.0     �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.442076 0.491338 0.755387 00:31
1 0.441078 0.443674 0.760773 00:31
2 0.417575 0.418758 0.792142 00:31
3 0.410825 0.417581 0.788498 00:34
4 0.403407 0.410941 0.798321 00:46
| �[0m 27      �[0m | �[0m 0.7983  �[0m | �[0m 0.0     �[0m | �[0m 0.0     �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.407006 0.419679 0.792142 00:08
1 0.390913 0.392631 0.810520 00:08
2 0.365560 0.394330 0.817491 00:08
3 0.378459 0.387244 0.820659 00:08
4 0.375275 0.385417 0.828897 00:08
| �[0m 28      �[0m | �[0m 0.8289  �[0m | �[0m 3.379   �[0m | �[0m 2.848   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.430604 0.469592 0.781210 00:45
1 0.423074 0.429704 0.797529 00:45
2 0.400120 0.393398 0.810995 00:45
3 0.382361 0.390651 0.816065 00:46
4 0.389520 0.401878 0.807193 00:46
| �[0m 29      �[0m | �[0m 0.8072  �[0m | �[0m 0.0     �[0m | �[0m 2.588   �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.396348 0.397454 0.806717 00:08
1 0.383342 0.386023 0.819550 00:07
2 0.369493 0.374401 0.820025 00:07
3 0.356015 0.366535 0.826204 00:08
4 0.341073 0.365241 0.826204 00:08
| �[0m 30      �[0m | �[0m 0.8262  �[0m | �[0m 1.217   �[0m | �[0m 5.622   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.809077 0.480591 0.782795 00:45
1 0.571318 0.497731 0.739068 00:45
2 0.514562 0.461726 0.781527 00:45
3 0.439822 0.451722 0.787231 00:44
4 0.419881 0.422125 0.801648 00:45
| �[0m 31      �[0m | �[0m 0.8016  �[0m | �[0m 8.0     �[0m | �[0m 8.0     �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.410521 0.435045 0.810044 00:08
1 0.363098 0.378001 0.821926 00:08
2 0.359525 0.364477 0.827788 00:08
3 0.354005 0.366507 0.821610 00:08
4 0.347293 0.362657 0.829373 00:08
| �[0m 32      �[0m | �[0m 0.8294  �[0m | �[0m 5.864   �[0m | �[0m 8.0     �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.498376 0.436696 0.794043 00:16
1 0.411699 0.435537 0.801331 00:16
2 0.385327 0.396916 0.820184 00:16
3 0.382020 0.389856 0.813371 00:16
4 0.373869 0.377804 0.820817 00:15
| �[0m 33      �[0m | �[0m 0.8208  �[0m | �[0m 1.776   �[0m | �[0m 8.0     �[0m | �[0m 2.212   �[0m |
epoch train_loss valid_loss accuracy time
0 0.404653 0.440106 0.772180 00:11
1 0.377931 0.393715 0.817332 00:11
2 0.373221 0.379273 0.826838 00:11
3 0.359682 0.362844 0.828422 00:11
4 0.340384 0.363072 0.828897 00:11
| �[0m 34      �[0m | �[0m 0.8289  �[0m | �[0m 5.777   �[0m | �[0m 2.2     �[0m | �[0m 1.31    �[0m |
epoch train_loss valid_loss accuracy time
0 0.520308 0.503207 0.749208 00:45
1 0.472501 0.451469 0.780418 00:45
2 0.454686 0.429175 0.784854 00:45
3 0.400800 0.413727 0.795469 00:44
4 0.405604 0.409770 0.801648 00:45
| �[0m 35      �[0m | �[0m 0.8016  �[0m | �[0m 2.748   �[0m | �[0m 5.915   �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.504537 0.501541 0.750317 00:45
1 0.465937 0.477715 0.773289 00:45
2 0.435364 0.481415 0.766635 00:45
3 0.425434 0.442198 0.772814 00:45
4 0.425779 0.458947 0.771863 00:45
| �[0m 36      �[0m | �[0m 0.7719  �[0m | �[0m 6.251   �[0m | �[0m 2.532   �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.420782 0.420721 0.791350 00:10
1 0.403576 0.408376 0.800222 00:10
2 0.390236 0.393624 0.820342 00:11
3 0.377777 0.389657 0.821610 00:11
4 0.382809 0.386011 0.820976 00:11
| �[0m 37      �[0m | �[0m 0.821   �[0m | �[0m 5.093   �[0m | �[0m 0.172   �[0m | �[0m 1.64    �[0m |
epoch train_loss valid_loss accuracy time
0 0.393575 0.397811 0.812262 00:08
1 0.378272 0.381915 0.815748 00:08
2 0.364799 0.369214 0.824620 00:08
3 0.355757 0.364554 0.826996 00:08
4 0.342090 0.362723 0.824303 00:08
| �[0m 38      �[0m | �[0m 0.8243  �[0m | �[0m 8.0     �[0m | �[0m 5.799   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.393693 0.396980 0.822085 00:11
1 0.361231 0.393146 0.813847 00:11
2 0.345645 0.379510 0.823986 00:11
3 0.349778 0.367077 0.826679 00:11
4 0.342390 0.362027 0.827788 00:11
| �[0m 39      �[0m | �[0m 0.8278  �[0m | �[0m 1.62    �[0m | �[0m 3.832   �[0m | �[0m 1.151   �[0m |
epoch train_loss valid_loss accuracy time
0 0.832737 0.491002 0.771546 00:43
1 0.627948 0.553552 0.764734 00:43
2 0.498901 0.467162 0.791984 00:46
3 0.431196 0.444576 0.785646 00:43
4 0.399745 0.427060 0.796578 00:42
| �[0m 40      �[0m | �[0m 0.7966  �[0m | �[0m 2.198   �[0m | �[0m 8.0     �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.511301 0.514401 0.751267 00:43
1 0.447332 0.445157 0.751109 00:43
2 0.451125 0.438327 0.750951 00:42
3 0.445883 0.443266 0.751267 00:42
4 0.444816 0.438459 0.764100 00:42
| �[0m 41      �[0m | �[0m 0.7641  �[0m | �[0m 8.0     �[0m | �[0m 1.03    �[0m | �[0m 4.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.408504 0.413275 0.797212 00:15
1 0.392707 0.399085 0.805767 00:15
2 0.379938 0.395550 0.817807 00:15
3 0.375288 0.383186 0.820817 00:15
4 0.360417 0.375098 0.823194 00:16
| �[0m 42      �[0m | �[0m 0.8232  �[0m | �[0m 0.0     �[0m | �[0m 2.504   �[0m | �[0m 2.135   �[0m |
epoch train_loss valid_loss accuracy time
0 0.399371 0.415196 0.801014 00:07
1 0.367804 0.392020 0.810995 00:06
2 0.362288 0.385124 0.820659 00:07
3 0.344728 0.371339 0.823669 00:07
4 0.345769 0.362059 0.829373 00:07
| �[0m 43      �[0m | �[0m 0.8294  �[0m | �[0m 0.0     �[0m | �[0m 5.441   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.397157 0.431003 0.803866 00:06
1 0.394964 0.396448 0.810361 00:06
2 0.378584 0.387943 0.820659 00:07
3 0.371601 0.386186 0.818283 00:07
4 0.369759 0.384339 0.827630 00:07
| �[0m 44      �[0m | �[0m 0.8276  �[0m | �[0m 4.636   �[0m | �[0m 1.476   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.408654 0.426806 0.791191 00:12
1 0.394184 0.406586 0.786439 00:12
2 0.369625 0.372680 0.822560 00:12
3 0.349444 0.368142 0.823828 00:12
4 0.351684 0.363406 0.826204 00:12
| �[0m 45      �[0m | �[0m 0.8262  �[0m | �[0m 0.0     �[0m | �[0m 7.071   �[0m | �[0m 2.071   �[0m |
epoch train_loss valid_loss accuracy time
0 0.400293 0.416098 0.811629 00:08
1 0.377387 0.433395 0.807034 00:08
2 0.368131 0.395448 0.796420 00:08
3 0.367750 0.376879 0.817174 00:08
4 0.362124 0.371432 0.821134 00:08
| �[0m 46      �[0m | �[0m 0.8211  �[0m | �[0m 4.26    �[0m | �[0m 6.934   �[0m | �[0m 1.79    �[0m |
epoch train_loss valid_loss accuracy time
0 0.404579 0.437443 0.814797 00:07
1 0.375342 0.380416 0.824937 00:07
2 0.365835 0.377617 0.812738 00:07
3 0.354619 0.364503 0.827471 00:07
4 0.340603 0.363488 0.827947 00:07
| �[0m 47      �[0m | �[0m 0.8279  �[0m | �[0m 6.579   �[0m | �[0m 6.485   �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.384890 0.440342 0.812579 00:08
1 0.371483 0.387200 0.813847 00:09
2 0.365951 0.378071 0.818283 00:09
3 0.362965 0.369994 0.821610 00:09
4 0.356483 0.365151 0.826521 00:09
| �[0m 48      �[0m | �[0m 0.8265  �[0m | �[0m 8.0     �[0m | �[0m 4.293   �[0m | �[0m 1.74    �[0m |
epoch train_loss valid_loss accuracy time
0 0.386308 0.389250 0.815431 00:08
1 0.368402 0.389338 0.814956 00:09
2 0.362211 0.377196 0.824778 00:09
3 0.356135 0.362951 0.829531 00:09
4 0.341577 0.362476 0.830799 00:09
| �[0m 49      �[0m | �[0m 0.8308  �[0m | �[0m 7.909   �[0m | �[0m 7.827   �[0m | �[0m 1.323   �[0m |
epoch train_loss valid_loss accuracy time
0 0.426044 0.422882 0.791984 00:08
1 0.375756 0.381810 0.817491 00:09
2 0.363932 0.375904 0.818916 00:09
3 0.349442 0.365052 0.823986 00:09
4 0.344509 0.363027 0.830323 00:09
| �[0m 50      �[0m | �[0m 0.8303  �[0m | �[0m 4.946   �[0m | �[0m 1.246   �[0m | �[0m 1.589   �[0m |
epoch train_loss valid_loss accuracy time
0 0.388522 0.431909 0.820976 00:07
1 0.372532 0.448644 0.751109 00:07
2 0.358823 0.373322 0.823669 00:07
3 0.352838 0.362424 0.831591 00:07
4 0.352949 0.361356 0.831432 00:07
| �[95m 51      �[0m | �[95m 0.8314  �[0m | �[95m 5.664   �[0m | �[95m 2.626   �[0m | �[95m 0.003048�[0m |
epoch train_loss valid_loss accuracy time
0 0.389195 0.390032 0.817332 00:06
1 0.369993 0.382199 0.819708 00:07
2 0.362801 0.373282 0.826521 00:06
3 0.359760 0.363597 0.824303 00:06
4 0.344525 0.362097 0.828897 00:07
| �[0m 52      �[0m | �[0m 0.8289  �[0m | �[0m 1.287   �[0m | �[0m 3.505   �[0m | �[0m 0.06804 �[0m |
epoch train_loss valid_loss accuracy time
0 0.394940 0.403165 0.814639 00:06
1 0.371518 0.452118 0.806876 00:06
2 0.364734 0.377214 0.824461 00:06
3 0.347968 0.365335 0.823511 00:07
4 0.345476 0.363670 0.827155 00:07
| �[0m 53      �[0m | �[0m 0.8272  �[0m | �[0m 1.606   �[0m | �[0m 7.998   �[0m | �[0m 0.2009  �[0m |
epoch train_loss valid_loss accuracy time
0 0.500322 0.519261 0.757129 00:10
1 0.413270 0.423630 0.801965 00:11
2 0.380234 0.395588 0.813371 00:12
3 0.361677 0.378123 0.817174 00:12
4 0.374629 0.373772 0.820025 00:12
| �[0m 54      �[0m | �[0m 0.82    �[0m | �[0m 4.579   �[0m | �[0m 5.017   �[0m | �[0m 2.928   �[0m |
| �[0m 55      �[0m | �[0m 0.8259  �[0m | �[0m 0.02565 �[0m | �[0m 3.699   �[0m | �[0m 0.9808  �[0m |
epoch train_loss valid_loss accuracy time
0 0.452787 0.443697 0.768695 00:11
1 0.428332 0.415454 0.800697 00:11
2 0.396522 0.402850 0.807668 00:12
3 0.424802 0.414648 0.783587 00:12
4 0.385055 0.392359 0.801489 00:12
| �[0m 56      �[0m | �[0m 0.8015  �[0m | �[0m 1.927   �[0m | �[0m 5.92    �[0m | �[0m 2.53    �[0m |
epoch train_loss valid_loss accuracy time
0 0.435597 0.438222 0.810836 00:19
1 0.399920 0.531189 0.770754 00:19
2 0.403408 0.409382 0.804816 00:18
3 0.363519 0.383823 0.815906 00:19
4 0.360030 0.377621 0.819708 00:19
| �[0m 57      �[0m | �[0m 0.8197  �[0m | �[0m 0.7796  �[0m | �[0m 4.576   �[0m | �[0m 3.952   �[0m |
epoch train_loss valid_loss accuracy time
0 0.388445 0.420243 0.800539 00:07
1 0.372912 0.369659 0.827630 00:07
2 0.354443 0.366757 0.828105 00:07
3 0.352468 0.366038 0.822560 00:07
4 0.347822 0.362001 0.829690 00:07
| �[0m 58      �[0m | �[0m 0.8297  �[0m | �[0m 3.525   �[0m | �[0m 4.198   �[0m | �[0m 0.02314 �[0m |
epoch train_loss valid_loss accuracy time
0 0.444627 0.431739 0.787072 00:12
1 0.392637 0.412985 0.799747 00:12
2 0.369733 0.396133 0.802440 00:12
3 0.365821 0.373095 0.820342 00:12
4 0.370486 0.371560 0.819392 00:12
| �[0m 59      �[0m | �[0m 0.8194  �[0m | �[0m 6.711   �[0m | �[0m 3.848   �[0m | �[0m 2.395   �[0m |
epoch train_loss valid_loss accuracy time
0 0.396831 0.389045 0.809569 00:07
1 0.371171 0.375065 0.818600 00:07
2 0.350309 0.371795 0.824620 00:07
3 0.359700 0.363041 0.828739 00:07
4 0.345735 0.361556 0.830799 00:07
| �[0m 60      �[0m | �[0m 0.8308  �[0m | �[0m 4.914   �[0m | �[0m 7.944   �[0m | �[0m 0.9998  �[0m |
epoch train_loss valid_loss accuracy time
0 0.422853 0.412691 0.804341 00:09
1 0.375209 0.394692 0.817174 00:09
2 0.365574 0.380376 0.820184 00:08
3 0.359143 0.363607 0.831115 00:08
4 0.347991 0.361650 0.827947 00:08
| �[0m 61      �[0m | �[0m 0.8279  �[0m | �[0m 7.962   �[0m | �[0m 6.151   �[0m | �[0m 1.119   �[0m |
epoch train_loss valid_loss accuracy time
0 0.388147 0.405885 0.810678 00:07
1 0.367743 0.391867 0.807826 00:07
2 0.366964 0.362980 0.828739 00:07
3 0.363402 0.363396 0.829531 00:07
4 0.351094 0.362245 0.829214 00:07
| �[0m 62      �[0m | �[0m 0.8292  �[0m | �[0m 2.583   �[0m | �[0m 6.996   �[0m | �[0m 0.008348�[0m |
epoch train_loss valid_loss accuracy time
0 0.432724 0.390516 0.815114 00:11
1 0.407100 0.401564 0.811153 00:11
2 0.359414 0.384463 0.820342 00:11
3 0.358061 0.371844 0.826362 00:12
4 0.345357 0.362986 0.831115 00:12
| �[0m 63      �[0m | �[0m 0.8311  �[0m | �[0m 0.2249  �[0m | �[0m 8.0     �[0m | �[0m 2.823   �[0m |
epoch train_loss valid_loss accuracy time
0 0.418990 0.420463 0.808618 00:07
1 0.389830 0.398110 0.816223 00:08
2 0.382975 0.387620 0.814956 00:06
3 0.384093 0.379607 0.819392 00:06
4 0.358019 0.371140 0.823828 00:08
| �[0m 64      �[0m | �[0m 0.8238  �[0m | �[0m 5.764   �[0m | �[0m 5.509   �[0m | �[0m 1.482   �[0m |
epoch train_loss valid_loss accuracy time
0 0.385147 0.399680 0.806242 00:09
1 0.376032 0.381131 0.822560 00:09
2 0.363870 0.378227 0.822402 00:09
3 0.351089 0.368790 0.826838 00:09
4 0.340404 0.361807 0.829214 00:09
| �[0m 65      �[0m | �[0m 0.8292  �[0m | �[0m 1.048   �[0m | �[0m 2.939   �[0m | �[0m 1.922   �[0m |
epoch train_loss valid_loss accuracy time
0 0.526628 0.507236 0.746673 00:17
1 0.460229 0.455675 0.765684 00:19
2 0.417427 0.421368 0.785963 00:19
3 0.462800 0.458844 0.773923 00:19
4 0.449479 0.456627 0.783587 00:19
| �[0m 66      �[0m | �[0m 0.7836  �[0m | �[0m 3.68    �[0m | �[0m 3.977   �[0m | �[0m 3.919   �[0m |
epoch train_loss valid_loss accuracy time
0 0.621678 0.559080 0.749049 00:10
1 0.457104 0.473610 0.758397 00:11
2 0.416287 0.416622 0.764575 00:13
3 0.388107 0.403844 0.811945 00:13
4 0.384231 0.396397 0.813055 00:13
| �[0m 67      �[0m | �[0m 0.8131  �[0m | �[0m 5.907   �[0m | �[0m 0.9452  �[0m | �[0m 2.168   �[0m |
epoch train_loss valid_loss accuracy time
0 0.410105 0.416171 0.808618 00:11
1 0.381669 0.400109 0.809094 00:13
2 0.377539 0.403879 0.803074 00:13
3 0.374653 0.389122 0.808618 00:13
4 0.366356 0.380526 0.814005 00:13
| �[0m 68      �[0m | �[0m 0.814   �[0m | �[0m 7.981   �[0m | �[0m 2.796   �[0m | �[0m 2.78    �[0m |
epoch train_loss valid_loss accuracy time
0 0.380539 0.420856 0.815589 00:07
1 0.368232 0.424276 0.803866 00:05
2 0.358194 0.378501 0.827155 00:06
3 0.353427 0.362224 0.829690 00:07
4 0.344443 0.361554 0.826838 00:07
| �[0m 69      �[0m | �[0m 0.8268  �[0m | �[0m 7.223   �[0m | �[0m 3.762   �[0m | �[0m 0.6961  �[0m |
epoch train_loss valid_loss accuracy time
0 0.445924 0.488581 0.752693 00:17
1 0.410709 0.400962 0.813688 00:18
2 0.373518 0.393235 0.820184 00:18
3 0.364160 0.378920 0.820817 00:17
4 0.357551 0.371629 0.825412 00:17
| �[0m 70      �[0m | �[0m 0.8254  �[0m | �[0m 0.009375�[0m | �[0m 5.081   �[0m | �[0m 3.79    �[0m |
epoch train_loss valid_loss accuracy time
0 0.410323 0.442670 0.814797 00:08
1 0.379279 0.405500 0.807034 00:08
2 0.387576 0.392448 0.819708 00:09
3 0.371622 0.389167 0.823035 00:09
4 0.374182 0.386964 0.825095 00:09
| �[0m 71      �[0m | �[0m 0.8251  �[0m | �[0m 3.293   �[0m | �[0m 2.76    �[0m | �[0m 1.061   �[0m |
| �[0m 72      �[0m | �[0m 0.8308  �[0m | �[0m 4.589   �[0m | �[0m 7.13    �[0m | �[0m 0.003179�[0m |
epoch train_loss valid_loss accuracy time
0 0.391509 0.399966 0.806400 00:06
1 0.366694 0.405719 0.823828 00:07
2 0.359751 0.375496 0.822877 00:07
3 0.347678 0.361711 0.830799 00:07
4 0.336896 0.361922 0.828580 00:07
| �[0m 73      �[0m | �[0m 0.8286  �[0m | �[0m 7.118   �[0m | �[0m 5.204   �[0m | �[0m 0.5939  �[0m |
epoch train_loss valid_loss accuracy time
0 0.435485 0.414883 0.808143 00:08
1 0.373591 0.417138 0.814005 00:09
2 0.369590 0.375724 0.820184 00:09
3 0.370829 0.368655 0.829531 00:09
4 0.346463 0.366307 0.825412 00:09
| �[0m 74      �[0m | �[0m 0.8254  �[0m | �[0m 6.837   �[0m | �[0m 7.988   �[0m | �[0m 1.055   �[0m |
| �[0m 75      �[0m | �[0m 0.8259  �[0m | �[0m 0.6629  �[0m | �[0m 7.012   �[0m | �[0m 0.03222 �[0m |
| �[0m 76      �[0m | �[0m 0.8311  �[0m | �[0m 5.177   �[0m | �[0m 1.457   �[0m | �[0m 0.5857  �[0m |
epoch train_loss valid_loss accuracy time
0 0.455741 0.530008 0.794994 00:11
1 0.421961 0.423317 0.805292 00:11
2 0.405799 0.405729 0.807351 00:12
3 0.383895 0.395092 0.816857 00:12
4 0.378882 0.386044 0.818758 00:12
| �[0m 77      �[0m | �[0m 0.8188  �[0m | �[0m 3.308   �[0m | �[0m 4.533   �[0m | �[0m 2.048   �[0m |
epoch train_loss valid_loss accuracy time
0 0.412261 0.410950 0.798954 00:08
1 0.382314 0.408471 0.803390 00:09
2 0.347488 0.387500 0.815273 00:09
3 0.343408 0.372050 0.821451 00:09
4 0.344963 0.366158 0.822719 00:09
| �[0m 78      �[0m | �[0m 0.8227  �[0m | �[0m 0.4036  �[0m | �[0m 7.997   �[0m | �[0m 1.439   �[0m |
epoch train_loss valid_loss accuracy time
0 0.412856 0.410016 0.799113 00:08
1 0.410852 0.416405 0.788498 00:08
2 0.373897 0.384385 0.824303 00:09
3 0.353164 0.366129 0.822719 00:09
4 0.353253 0.362269 0.826362 00:09
| �[0m 79      �[0m | �[0m 0.8264  �[0m | �[0m 3.438   �[0m | �[0m 7.982   �[0m | �[0m 1.829   �[0m |
epoch train_loss valid_loss accuracy time
0 0.419316 0.408936 0.798162 00:08
1 0.393826 0.390526 0.820184 00:08
2 0.372879 0.374823 0.822719 00:08
3 0.358019 0.370913 0.820342 00:08
4 0.346020 0.362252 0.829690 00:08
| �[0m 80      �[0m | �[0m 0.8297  �[0m | �[0m 6.88    �[0m | �[0m 2.404   �[0m | �[0m 1.666   �[0m |
epoch train_loss valid_loss accuracy time
0 0.433481 0.437320 0.790082 00:18
1 0.415280 0.402946 0.814164 00:18
2 0.365575 0.376285 0.822877 00:18
3 0.363206 0.371865 0.820501 00:18
4 0.356401 0.370252 0.823828 00:18
| �[0m 81      �[0m | �[0m 0.8238  �[0m | �[0m 0.03221 �[0m | �[0m 1.306   �[0m | �[0m 3.909   �[0m |
epoch train_loss valid_loss accuracy time
0 0.393150 0.420964 0.783745 00:07
1 0.375065 0.371380 0.823986 00:07
2 0.362952 0.387037 0.813688 00:07
3 0.347245 0.370225 0.824937 00:07
4 0.348406 0.361420 0.830640 00:07
| �[0m 82      �[0m | �[0m 0.8306  �[0m | �[0m 1.575   �[0m | �[0m 2.689   �[0m | �[0m 0.8684  �[0m |
epoch train_loss valid_loss accuracy time
0 0.395530 0.397430 0.818600 00:06
1 0.358679 0.396773 0.818283 00:07
2 0.349305 0.372877 0.823828 00:07
3 0.347346 0.363006 0.828422 00:07
4 0.335652 0.362567 0.830957 00:07
| �[0m 83      �[0m | �[0m 0.831   �[0m | �[0m 2.765   �[0m | �[0m 5.439   �[0m | �[0m 0.04047 �[0m |
epoch train_loss valid_loss accuracy time
0 0.380831 0.386072 0.814322 00:05
1 0.369778 0.392521 0.810520 00:07
2 0.368286 0.383131 0.816857 00:07
3 0.356585 0.367839 0.821768 00:07
4 0.344722 0.366639 0.825253 00:07
| �[0m 84      �[0m | �[0m 0.8253  �[0m | �[0m 0.1961  �[0m | �[0m 4.123   �[0m | �[0m 0.02039 �[0m |
epoch train_loss valid_loss accuracy time
0 0.502605 0.463703 0.772180 00:11
1 0.407231 0.404386 0.804499 00:13
2 0.406849 0.411254 0.817491 00:12
3 0.379910 0.389118 0.817174 00:12
4 0.370964 0.379835 0.821293 00:12
| �[0m 85      �[0m | �[0m 0.8213  �[0m | �[0m 7.937   �[0m | �[0m 7.939   �[0m | �[0m 2.895   �[0m |
epoch train_loss valid_loss accuracy time
0 0.402574 0.431237 0.786755 00:11
1 0.380494 0.392253 0.817966 00:11
2 0.363284 0.393907 0.815748 00:11
3 0.355680 0.368488 0.822560 00:12
4 0.362978 0.367755 0.823511 00:12
| �[0m 86      �[0m | �[0m 0.8235  �[0m | �[0m 0.06921 �[0m | �[0m 5.7     �[0m | �[0m 2.778   �[0m |
epoch train_loss valid_loss accuracy time
0 0.448572 0.451255 0.789766 00:10
1 0.417093 0.411838 0.808143 00:11
2 0.400799 0.400185 0.816223 00:12
3 0.378641 0.385082 0.820342 00:12
4 0.365278 0.380320 0.818441 00:11
| �[0m 87      �[0m | �[0m 0.8184  �[0m | �[0m 7.965   �[0m | �[0m 5.261   �[0m | �[0m 2.661   �[0m |
epoch train_loss valid_loss accuracy time
0 0.399453 0.429288 0.783112 00:07
1 0.368987 0.376985 0.825729 00:07
2 0.358226 0.372103 0.830165 00:07
3 0.348940 0.362069 0.831115 00:07
4 0.341437 0.361470 0.830482 00:07
| �[0m 88      �[0m | �[0m 0.8305  �[0m | �[0m 2.792   �[0m | �[0m 7.917   �[0m | �[0m 0.761   �[0m |
| �[0m 89      �[0m | �[0m 0.8294  �[0m | �[0m 7.995   �[0m | �[0m 7.186   �[0m | �[0m 0.1199  �[0m |
epoch train_loss valid_loss accuracy time
0 0.390914 0.403780 0.799905 00:07
1 0.375241 0.400406 0.821610 00:07
2 0.359710 0.373233 0.826046 00:07
3 0.351108 0.367255 0.823986 00:07
4 0.348230 0.362827 0.830482 00:07
| �[0m 90      �[0m | �[0m 0.8305  �[0m | �[0m 2.526   �[0m | �[0m 3.741   �[0m | �[0m 0.1186  �[0m |
| �[0m 91      �[0m | �[0m 0.8294  �[0m | �[0m 0.03285 �[0m | �[0m 5.742   �[0m | �[0m 0.9747  �[0m |
epoch train_loss valid_loss accuracy time
0 0.384485 0.413711 0.816857 00:06
1 0.358383 0.382125 0.813055 00:06
2 0.349632 0.380644 0.825570 00:07
3 0.350005 0.363857 0.833492 00:07
4 0.341342 0.362533 0.829848 00:07
| �[0m 92      �[0m | �[0m 0.8298  �[0m | �[0m 4.112e-0�[0m | �[0m 8.0     �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.461806 0.454314 0.752218 00:11
1 0.429429 0.452721 0.768061 00:13
2 0.408792 0.422315 0.785013 00:12
3 0.400290 0.411134 0.805292 00:12
4 0.396974 0.409337 0.804658 00:12
| �[0m 93      �[0m | �[0m 0.8047  �[0m | �[0m 4.563   �[0m | �[0m 0.6868  �[0m | �[0m 2.461   �[0m |
epoch train_loss valid_loss accuracy time
0 0.464920 0.440214 0.772655 00:11
1 0.401611 0.400240 0.805767 00:11
2 0.391905 0.414512 0.798638 00:13
3 0.391561 0.407271 0.805608 00:13
4 0.387596 0.397586 0.814639 00:13
| �[0m 94      �[0m | �[0m 0.8146  �[0m | �[0m 4.697   �[0m | �[0m 3.412   �[0m | �[0m 2.514   �[0m |
epoch train_loss valid_loss accuracy time
0 0.417917 0.423588 0.755070 00:05
1 0.390158 0.396916 0.817649 00:07
2 0.374562 0.389764 0.828422 00:07
3 0.368317 0.385268 0.822719 00:07
4 0.373367 0.385092 0.827471 00:07
| �[0m 95      �[0m | �[0m 0.8275  �[0m | �[0m 5.04    �[0m | �[0m 0.4492  �[0m | �[0m 0.5899  �[0m |
epoch train_loss valid_loss accuracy time
0 0.368991 0.388251 0.822402 00:07
1 0.363111 0.399670 0.814797 00:08
2 0.370390 0.377726 0.818441 00:10
3 0.353493 0.368603 0.823828 00:10
4 0.355493 0.368281 0.823669 00:09
| �[0m 96      �[0m | �[0m 0.8237  �[0m | �[0m 0.6025  �[0m | �[0m 2.712   �[0m | �[0m 1.166   �[0m |
epoch train_loss valid_loss accuracy time
0 0.426161 0.417567 0.793726 00:07
1 0.397633 0.392852 0.818124 00:08
2 0.371073 0.371678 0.821768 00:09
3 0.351716 0.361607 0.831749 00:10
4 0.347291 0.359881 0.832066 00:09
| �[95m 97      �[0m | �[95m 0.8321  �[0m | �[95m 6.389   �[0m | �[95m 3.648   �[0m | �[95m 1.016   �[0m |
epoch train_loss valid_loss accuracy time
0 0.408937 0.440755 0.785805 00:12
1 0.369800 0.379622 0.819867 00:12
2 0.356048 0.370119 0.827630 00:11
3 0.354271 0.366091 0.826362 00:11
4 0.359324 0.366020 0.827313 00:12
| �[0m 98      �[0m | �[0m 0.8273  �[0m | �[0m 0.5927  �[0m | �[0m 1.715   �[0m | �[0m 2.847   �[0m |
epoch train_loss valid_loss accuracy time
0 0.443831 0.469775 0.753802 00:11
1 0.411636 0.423362 0.787864 00:11
2 0.411669 0.421874 0.797529 00:13
3 0.392242 0.395377 0.816382 00:12
4 0.383947 0.390270 0.818441 00:12
| �[0m 99      �[0m | �[0m 0.8184  �[0m | �[0m 6.602   �[0m | �[0m 2.266   �[0m | �[0m 2.535   �[0m |
epoch train_loss valid_loss accuracy time
0 0.413688 0.469193 0.811153 00:08
1 0.383908 0.401520 0.815589 00:08
2 0.363908 0.375178 0.824937 00:08
3 0.360694 0.366536 0.828739 00:08
4 0.341851 0.362886 0.830165 00:08
| �[0m 100     �[0m | �[0m 0.8302  �[0m | �[0m 3.099   �[0m | �[0m 6.058   �[0m | �[0m 1.058   �[0m |
epoch train_loss valid_loss accuracy time
0 0.404324 0.402310 0.807668 00:09
1 0.383628 0.428126 0.780577 00:09
2 0.360917 0.375198 0.826996 00:09
3 0.349922 0.367114 0.828264 00:09
4 0.338715 0.363762 0.830165 00:09
| �[0m 101     �[0m | �[0m 0.8302  �[0m | �[0m 4.268   �[0m | �[0m 5.28    �[0m | �[0m 1.397   �[0m |
epoch train_loss valid_loss accuracy time
0 0.381317 0.410546 0.815431 00:06
1 0.355632 0.424577 0.802440 00:06
2 0.361193 0.377675 0.818124 00:07
3 0.355751 0.363205 0.827313 00:07
4 0.338802 0.362403 0.827471 00:07
| �[0m 102     �[0m | �[0m 0.8275  �[0m | �[0m 4.735   �[0m | �[0m 3.398   �[0m | �[0m 0.01904 �[0m |
epoch train_loss valid_loss accuracy time
0 0.381150 0.398499 0.813213 00:07
1 0.370997 0.413040 0.798162 00:07
2 0.364154 0.369066 0.823352 00:07
3 0.354434 0.362925 0.825095 00:07
4 0.348736 0.362125 0.824461 00:07
| �[0m 103     �[0m | �[0m 0.8245  �[0m | �[0m 0.0     �[0m | �[0m 6.371   �[0m | �[0m 9.721e-0�[0m |
epoch train_loss valid_loss accuracy time
0 0.388378 0.412839 0.810361 00:06
1 0.377533 0.380432 0.822560 00:06
2 0.356242 0.372250 0.824937 00:07
3 0.346274 0.364236 0.830323 00:07
4 0.349273 0.362597 0.830482 00:07
| �[0m 104     �[0m | �[0m 0.8305  �[0m | �[0m 1.683   �[0m | �[0m 6.381   �[0m | �[0m 0.8564  �[0m |
epoch train_loss valid_loss accuracy time
0 0.410286 0.413506 0.801648 00:07
1 0.386787 0.397328 0.812262 00:07
2 0.365268 0.390419 0.825570 00:07
3 0.368566 0.386955 0.828264 00:07
4 0.362751 0.383124 0.830482 00:07
| �[0m 105     �[0m | �[0m 0.8305  �[0m | �[0m 6.387   �[0m | �[0m 2.333   �[0m | �[0m 0.4877  �[0m |
epoch train_loss valid_loss accuracy time
0 0.432347 0.450028 0.780894 00:17
1 0.414766 0.402565 0.809569 00:17
2 0.382495 0.382281 0.812421 00:18
3 0.366852 0.373158 0.822085 00:18
4 0.353692 0.368471 0.824461 00:18
| �[0m 106     �[0m | �[0m 0.8245  �[0m | �[0m 0.9452  �[0m | �[0m 6.902   �[0m | �[0m 3.991   �[0m |
epoch train_loss valid_loss accuracy time
0 0.397604 0.407966 0.814797 00:06
1 0.386755 0.382698 0.800063 00:07
2 0.361672 0.373610 0.823194 00:07
3 0.345655 0.363278 0.829848 00:07
4 0.349931 0.362171 0.830957 00:07
| �[0m 107     �[0m | �[0m 0.831   �[0m | �[0m 5.475   �[0m | �[0m 5.721   �[0m | �[0m 0.02659 �[0m |
| �[0m 108     �[0m | �[0m 0.8308  �[0m | �[0m 4.597   �[0m | �[0m 7.994   �[0m | �[0m 0.1318  �[0m |
epoch train_loss valid_loss accuracy time
0 0.391720 0.395161 0.819550 00:06
1 0.372859 0.387859 0.810203 00:07
2 0.367385 0.376952 0.812262 00:07
3 0.346810 0.365312 0.827155 00:07
4 0.341221 0.363528 0.829056 00:07
| �[0m 109     �[0m | �[0m 0.8291  �[0m | �[0m 2.391   �[0m | �[0m 4.915   �[0m | �[0m 0.8695  �[0m |
epoch train_loss valid_loss accuracy time
0 0.391635 0.402138 0.804658 00:08
1 0.402425 0.468013 0.803866 00:09
2 0.368443 0.372927 0.823669 00:09
3 0.361579 0.364913 0.826679 00:09
4 0.347989 0.363416 0.828580 00:09
| �[0m 110     �[0m | �[0m 0.8286  �[0m | �[0m 5.711   �[0m | �[0m 7.031   �[0m | �[0m 1.157   �[0m |
epoch train_loss valid_loss accuracy time
0 0.393225 0.414461 0.812738 00:07
1 0.377641 0.381910 0.820659 00:07
2 0.363979 0.375042 0.826838 00:07
3 0.349867 0.363547 0.830165 00:07
4 0.350035 0.364254 0.830006 00:07
| �[0m 111     �[0m | �[0m 0.83    �[0m | �[0m 3.867   �[0m | �[0m 7.351   �[0m | �[0m 0.7373  �[0m |
| �[0m 112     �[0m | �[0m 0.8275  �[0m | �[0m 5.568   �[0m | �[0m 0.8565  �[0m | �[0m 0.9522  �[0m |
| �[0m 113     �[0m | �[0m 0.8311  �[0m | �[0m 5.553   �[0m | �[0m 1.45    �[0m | �[0m 0.0     �[0m |
epoch train_loss valid_loss accuracy time
0 0.404215 0.421293 0.817807 00:10
1 0.380753 0.388533 0.803074 00:09
2 0.362475 0.381199 0.820342 00:09
3 0.353416 0.368431 0.824620 00:09
4 0.335854 0.365448 0.825570 00:09
| �[0m 114     �[0m | �[0m 0.8256  �[0m | �[0m 3.949   �[0m | �[0m 1.125   �[0m | �[0m 1.109   �[0m |
epoch train_loss valid_loss accuracy time
0 0.397351 0.482881 0.766160 00:08
1 0.373190 0.372999 0.819075 00:08
2 0.354075 0.368455 0.826046 00:08
3 0.350617 0.362793 0.830323 00:08
4 0.347381 0.361980 0.831274 00:08
| �[0m 115     �[0m | �[0m 0.8313  �[0m | �[0m 7.258   �[0m | �[0m 4.724   �[0m | �[0m 1.639   �[0m |
=============================================================
optimizer.max['target']
0.8428390622138977
{key: 2**int(value)
  for key, value in optimizer.max['params'].items()}
{'pow_n_a': 2, 'pow_n_d': 64, 'pow_n_steps': 1}

Out of memory dataset

If your dataset is so big it doesn't fit in memory, you can load a chunk of it each epoch.

df = pd.read_csv(path/'adult.csv')
df_main,df_valid = df.iloc[:-1000].copy(),df.iloc[-1000:].copy()
# choose size that fit in memory
dataset_size = 1000
# load chunk with your own code
def load_chunk():
    return df_main.sample(dataset_size).copy()
df_small = load_chunk()
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 
             'relationship', 'race', 'native-country', 'sex']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df_small))
to = TabularPandas(df_small, procs, cat_names, cont_names, y_names="salary", y_block = CategoryBlock(), 
                   splits=None, do_setup=True)
# save the validation set
to_valid = to.new(df_valid)
to_valid.process()
val_dl = TabDataLoader(to_valid.train)
len(to.train)
1000
class ReloadCallback(Callback):
    def begin_epoch(self): 
        df_small = load_chunk()
        to_new = to.new(df_small)
        to_new.process()
        trn_dl = TabDataLoader(to_new.train)
        self.learn.dls = DataLoaders(trn_dl, val_dl).cuda()
dls = to.dataloaders()
emb_szs = get_emb_sz(to)
model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=32, n_steps=1); 
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
learn.add_cb(ReloadCallback());
learn.fit_one_cycle(10)
epoch train_loss valid_loss accuracy time
0 0.587740 0.550544 0.756000 00:01
1 0.545411 0.515772 0.782000 00:01
2 0.484289 0.468586 0.813000 00:01
3 0.447111 0.435774 0.817000 00:01
4 0.449050 0.394715 0.819000 00:01
5 0.428863 0.382005 0.835000 00:01
6 0.382100 0.404258 0.826000 00:01
7 0.383915 0.376179 0.833000 00:01
8 0.389460 0.367857 0.834000 00:01
9 0.376486 0.367577 0.834000 00:01
Owner
Mikhail Grankin
Mikhail Grankin
YKKDetector For Python

YKKDetector OpenCVを利用した機械学習データをもとに、VRChatのスクリーンショットなどからYKKさん(もとい「幽狐族のお姉様」)を検出できるソフトウェアです。 マニュアル こちらから実行環境のセットアップから解説する詳細なマニュアルをご覧いただけます。 ライセンス 本ソフトウェア

あんふぃとらいと 5 Dec 07, 2021
RTSeg: Real-time Semantic Segmentation Comparative Study

Real-time Semantic Segmentation Comparative Study The repository contains the official TensorFlow code used in our papers: RTSEG: REAL-TIME SEMANTIC S

Mennatullah Siam 592 Nov 18, 2022
Repository to run object detection on a model trained on an autonomous driving dataset.

Autonomous Driving Object Detection on the Raspberry Pi 4 Description of Repository This repository contains code and instructions to configure the ne

Ethan 51 Nov 17, 2022
Dogs classification with Deep Metric Learning using some popular losses

Tsinghua Dogs classification with Deep Metric Learning 1. Introduction Tsinghua Dogs dataset Tsinghua Dogs is a fine-grained classification dataset fo

QuocThangNguyen 45 Nov 09, 2022
This repository is an unoffical PyTorch implementation of Medical segmentation in 3D and 2D.

Pytorch Medical Segmentation Read Chinese Introduction:Here! Recent Updates 2021.1.8 The train and test codes are released. 2021.2.6 A bug in dice was

EasyCV-Ellis 618 Dec 27, 2022
A paper using optimal transport to solve the graph matching problem.

GOAT A paper using optimal transport to solve the graph matching problem. https://arxiv.org/abs/2111.05366 Repo structure .github: Files specifying ho

neurodata 8 Jan 04, 2023
Video Instance Segmentation with a Propose-Reduce Paradigm (ICCV 2021)

Propose-Reduce VIS This repo contains the official implementation for the paper: Video Instance Segmentation with a Propose-Reduce Paradigm Huaijia Li

DV Lab 39 Nov 23, 2022
Code and data of the ACL 2021 paper: Few-Shot Text Ranking with Meta Adapted Synthetic Weak Supervision

MetaAdaptRank This repository provides the implementation of meta-learning to reweight synthetic weak supervision data described in the paper Few-Shot

THUNLP 5 Jun 16, 2022
Repository for the paper "PoseAug: A Differentiable Pose Augmentation Framework for 3D Human Pose Estimation", CVPR 2021.

PoseAug: A Differentiable Pose Augmentation Framework for 3D Human Pose Estimation Code repository for the paper: PoseAug: A Differentiable Pose Augme

Pyjcsx 328 Dec 17, 2022
RNN Predict Street Commercial Vitality

RNN-for-Predicting-Street-Vitality Code and dataset for Predicting the Vitality of Stores along the Street based on Business Type Sequence via Recurre

Zidong LIU 1 Dec 15, 2021
Python SDK for building, training, and deploying ML models

Overview of Kubeflow Fairing Kubeflow Fairing is a Python package that streamlines the process of building, training, and deploying machine learning (

Kubeflow 325 Dec 13, 2022
Highway networks implemented in PyTorch.

PyTorch Highway Networks Highway networks implemented in PyTorch. Just the MNIST example from PyTorch hacked to work with Highway layers. Todo Make th

Conner Vercellino 56 Dec 14, 2022
Graph-based community clustering approach to extract protein domains from a predicted aligned error matrix

Using a predicted aligned error matrix corresponding to an AlphaFold2 model , returns a series of lists of residue indices, where each list corresponds to a set of residues clustering together into a

Tristan Croll 24 Nov 23, 2022
Spatiotemporal resampling methods for mlr3

mlr3spatiotempcv Package website: release | dev Spatiotemporal resampling methods for mlr3. This package extends the mlr3 package framework with spati

45 Nov 21, 2022
QMagFace: Simple and Accurate Quality-Aware Face Recognition

Quality-Aware Face Recognition 26.11.2021 start readme QMagFace: Simple and Accurate Quality-Aware Face Recognition Research Paper Implementation - To

Philipp Terhörst 59 Jan 04, 2023
SimpleDepthEstimation - An unified codebase for NN-based monocular depth estimation methods

SimpleDepthEstimation Introduction This is an unified codebase for NN-based monocular depth estimation methods, the framework is based on detectron2 (

8 Dec 13, 2022
Lava-DL, but with PyTorch-Lightning flavour

Deep learning project seed Use this seed to start new deep learning / ML projects. Built in setup.py Built in requirements Examples with MNIST Badges

Sami BARCHID 4 Oct 31, 2022
[CVPR 2022] Deep Equilibrium Optical Flow Estimation

Deep Equilibrium Optical Flow Estimation This is the official repo for the paper Deep Equilibrium Optical Flow Estimation (CVPR 2022), by Shaojie Bai*

CMU Locus Lab 136 Dec 18, 2022
[NeurIPS 2020] This project provides a strong single-stage baseline for Long-Tailed Classification, Detection, and Instance Segmentation (LVIS).

A Strong Single-Stage Baseline for Long-Tailed Problems This project provides a strong single-stage baseline for Long-Tailed Classification (under Ima

Kaihua Tang 514 Dec 23, 2022
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data

Real-ESRGAN Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data Ported from https://github.com/xinntao/Real-ESRGAN Depend

Holy Wu 44 Dec 27, 2022