Skip to content

Commit d725101

Browse files
committed
feat: add warm_start matching scikit-learn
1 parent a0fd306 commit d725101

File tree

5 files changed

+63
-24
lines changed

5 files changed

+63
-24
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,7 @@ loaded_clf.load_model(saved_filepath)
360360
/!\ TabNetPretrainer Only : Percentage of input features to mask during pretraining.
361361

362362
Should be between 0 and 1. The bigger the harder the reconstruction task is.
363+
364+
- `warm_start` : bool (default=False)
365+
In order to match scikit-learn API, this is set to False.
366+
It allows fitting the same model twice and starting from a warm start.

census_example.ipynb

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,18 @@
158158
"metadata": {},
159159
"outputs": [],
160160
"source": [
161-
"clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
162-
" cat_dims=cat_dims,\n",
163-
" cat_emb_dim=1,\n",
164-
" optimizer_fn=torch.optim.Adam,\n",
165-
" optimizer_params=dict(lr=2e-2),\n",
166-
" scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
167-
" \"gamma\":0.9},\n",
168-
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
169-
" mask_type='entmax' # \"sparsemax\"\n",
161+
"tabnet_params = {\"cat_idxs\":cat_idxs,\n",
162+
" \"cat_dims\":cat_dims,\n",
163+
" \"cat_emb_dim\":1,\n",
164+
" \"optimizer_fn\":torch.optim.Adam,\n",
165+
" \"optimizer_params\":dict(lr=2e-2),\n",
166+
" \"scheduler_params\":{\"step_size\":50, # how to use learning rate scheduler\n",
167+
" \"gamma\":0.9},\n",
168+
" \"scheduler_fn\":torch.optim.lr_scheduler.StepLR,\n",
169+
" \"mask_type\":'entmax' # \"sparsemax\"\n",
170+
" }\n",
171+
"\n",
172+
"clf = TabNetClassifier(**tabnet_params\n",
170173
" )"
171174
]
172175
},
@@ -199,7 +202,7 @@
199202
"metadata": {},
200203
"outputs": [],
201204
"source": [
202-
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
205+
"max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
203206
]
204207
},
205208
{
@@ -210,17 +213,23 @@
210213
},
211214
"outputs": [],
212215
"source": [
213-
"clf.fit(\n",
214-
" X_train=X_train, y_train=y_train,\n",
215-
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
216-
" eval_name=['train', 'valid'],\n",
217-
" eval_metric=['auc'],\n",
218-
" max_epochs=max_epochs , patience=20,\n",
219-
" batch_size=1024, virtual_batch_size=128,\n",
220-
" num_workers=0,\n",
221-
" weights=1,\n",
222-
" drop_last=False\n",
223-
") "
216+
"# This illustrates the warm_start=False behaviour\n",
217+
"save_history = []\n",
218+
"for _ in range(2):\n",
219+
" clf.fit(\n",
220+
" X_train=X_train, y_train=y_train,\n",
221+
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
222+
" eval_name=['train', 'valid'],\n",
223+
" eval_metric=['auc'],\n",
224+
" max_epochs=max_epochs , patience=20,\n",
225+
" batch_size=1024, virtual_batch_size=128,\n",
226+
" num_workers=0,\n",
227+
" weights=1,\n",
228+
" drop_last=False\n",
229+
" )\n",
230+
" save_history.append(clf.history[\"valid_auc\"])\n",
231+
" \n",
232+
"assert np.all(np.array(save_history[0]) == np.array(save_history[1]))"
224233
]
225234
},
226235
{

pytorch_tabnet/abstract_model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
create_dataloaders,
1414
define_device,
1515
ComplexEncoder,
16-
check_input
16+
check_input,
17+
check_warm_start
1718
)
1819
from pytorch_tabnet.callbacks import (
1920
CallbackContainer,
@@ -73,6 +74,10 @@ def __post_init__(self):
7374
if self.verbose != 0:
7475
warnings.warn(f"Device used : {self.device}")
7576

77+
# create deep copies of mutable parameters
78+
self.optimizer_fn = copy.deepcopy(self.optimizer_fn)
79+
self.scheduler_fn = copy.deepcopy(self.scheduler_fn)
80+
7681
def __update__(self, **kwargs):
7782
"""
7883
Updates parameters.
@@ -120,6 +125,7 @@ def fit(
120125
callbacks=None,
121126
pin_memory=True,
122127
from_unsupervised=None,
128+
warm_start=False
123129
):
124130
"""Train a neural network stored in self.network
125131
Using train_dataloader for training data and
@@ -163,6 +169,8 @@ def fit(
163169
Whether to set pin_memory to True or False during training
164170
from_unsupervised: unsupervised trained model
165171
Use a previously self supervised model as starting weights
172+
warm_start: bool
173+
If True, current model parameters are used to start training
166174
"""
167175
# update model name
168176

@@ -184,6 +192,7 @@ def fit(
184192
self.loss_fn = loss_fn
185193

186194
check_input(X_train)
195+
check_warm_start(warm_start, from_unsupervised)
187196

188197
self.update_fit_params(
189198
X_train,
@@ -203,7 +212,8 @@ def fit(
203212
# Update parameters to match self pretraining
204213
self.__update__(**from_unsupervised.get_params())
205214

206-
if not hasattr(self, "network"):
215+
if not hasattr(self, "network") or not warm_start:
216+
# model has never been fitted before or warm_start is False
207217
self._set_network()
208218
self._update_network_params()
209219
self._set_metrics(eval_metric, eval_names)
@@ -542,6 +552,7 @@ def _predict_batch(self, X):
542552

543553
def _set_network(self):
544554
"""Setup the network and explain matrix."""
555+
torch.manual_seed(self.seed)
545556
self.network = tab_network.TabNet(
546557
self.input_dim,
547558
self.output_dim,

pytorch_tabnet/pretraining.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def fit(
5858
drop_last=True,
5959
callbacks=None,
6060
pin_memory=True,
61+
warm_start=False
6162
):
6263
"""Train a neural network stored in self.network
6364
Using train_dataloader for training data and
@@ -130,8 +131,10 @@ def fit(
130131
X_train, eval_set
131132
)
132133

133-
if not hasattr(self, 'network'):
134+
if not hasattr(self, "network") or not warm_start:
135+
# model has never been fitted before or warm_start is False
134136
self._set_network()
137+
135138
self._update_network_params()
136139
self._set_metrics(eval_names)
137140
self._set_optimizer()
@@ -168,6 +171,7 @@ def _set_network(self):
168171
"""Setup the network and explain matrix."""
169172
if not hasattr(self, 'pretraining_ratio'):
170173
self.pretraining_ratio = 0.5
174+
torch.manual_seed(self.seed)
171175
self.network = tab_network.TabNetPretraining(
172176
self.input_dim,
173177
pretraining_ratio=self.pretraining_ratio,

pytorch_tabnet/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
from sklearn.utils import check_array
88
import pandas as pd
9+
import warnings
910

1011

1112
class TorchDataset(Dataset):
@@ -349,4 +350,14 @@ def check_input(X):
349350
err_message = "Pandas DataFrame are not supported: apply X.values when calling fit"
350351
raise ValueError(err_message)
351352
check_array(X)
353+
354+
355+
def check_warm_start(warm_start, from_unsupervised):
    """
    Gives a warning about ambiguous usage of the two parameters.

    Parameters
    ----------
    warm_start : bool
        Whether fit should resume from the current model weights.
    from_unsupervised : object or None
        A previously trained self-supervised model, or None.
    """
    if warm_start and from_unsupervised is not None:
        warn_msg = "warm_start=True and from_unsupervised != None: "
        # += (not =): keep both halves of the message — plain assignment
        # silently dropped the first sentence of the warning
        warn_msg += "warm_start will be ignored, training will start from unsupervised weights"
        warnings.warn(warn_msg)
    return

0 commit comments

Comments
 (0)