zamba.models.efficientnet_models
Classes
TimeDistributedEfficientNet
Bases: ZambaVideoClassificationLightningModule
Source code in /home/runner/work/zamba/zamba/zamba/models/efficientnet_models.py
Attributes
backbone = torch.nn.ModuleList([efficientnet.get_submodule('blocks.5'), efficientnet.conv_head, efficientnet.bn2, efficientnet.global_pool])
instance-attribute
base = TimeDistributed(efficientnet, tdim=1)
instance-attribute
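`base` wraps the per-image EfficientNet in `TimeDistributed(efficientnet, tdim=1)` so that it is applied to every frame of a video batch. A common way to build such a wrapper is to fold the time dimension (dimension 1 here) into the batch dimension, run the wrapped module once, and unfold the result. The sketch below illustrates that idea with plain Python lists standing in for tensors; the helper name `time_distributed` is hypothetical, and this is not claimed to be zamba's actual `TimeDistributed` implementation.

```python
from typing import Callable, List, Sequence, TypeVar

Frame = TypeVar("Frame")
Feature = TypeVar("Feature")


def time_distributed(
    module: Callable[[List[Frame]], List[Feature]],
    batch: Sequence[Sequence[Frame]],
) -> List[List[Feature]]:
    """Hypothetical sketch: apply `module` to every frame of every video.

    `batch` is laid out as (batch, time, ...), i.e. the time axis is dim 1,
    matching tdim=1. `module` takes a flat list of frames (a "batch" of
    images) and returns one output per frame.
    """
    num_videos = len(batch)
    num_frames = len(batch[0])
    if any(len(video) != num_frames for video in batch):
        raise ValueError("all videos must have the same number of frames")

    # Fold: (batch, time, ...) -> (batch * time, ...)
    flat = [frame for video in batch for frame in video]

    # One call over all frames, as a CNN would process a larger image batch.
    outputs = module(flat)

    # Unfold: (batch * time, ...) -> (batch, time, ...)
    return [
        list(outputs[i * num_frames:(i + 1) * num_frames])
        for i in range(num_videos)
    ]


# Toy "backbone": maps each frame (a list of pixel values) to its mean.
def mean_pixel(frames):
    return [sum(f) / len(f) for f in frames]


videos = [
    [[0, 2], [4, 6], [8, 10]],  # video 0: 3 frames
    [[1, 1], [3, 3], [5, 5]],  # video 1: 3 frames
]
print(time_distributed(mean_pixel, videos))  # [[1.0, 5.0, 9.0], [1.0, 3.0, 5.0]]
```

Folding into one large batch lets the backbone process all frames in a single call instead of looping over frames.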
classifier = nn.Sequential(nn.Linear(num_backbone_final_features, 256), nn.Dropout(0.2), nn.ReLU(), nn.Linear(256, 64), nn.Flatten(), nn.Linear(64 * num_frames, self.num_classes))
instance-attribute
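The last layer of `classifier` takes `64 * num_frames` inputs because `nn.Linear` acts only on the last dimension, so the first two linear layers transform each frame's feature vector separately, and `nn.Flatten()` (which by default keeps the batch dimension and flattens the rest) then concatenates the per-frame 64-dimensional vectors. The pure-Python shape trace below follows the layers listed above, assuming the time-distributed backbone outputs a tensor of shape `(batch, num_frames, num_backbone_final_features)`; the values `1280` features, batch size 2, and 30 classes are illustrative assumptions, not values taken from this page.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def linear(shape: Shape, in_features: int, out_features: int) -> Shape:
    # nn.Linear acts on the last dimension and keeps all leading dimensions.
    if shape[-1] != in_features:
        raise ValueError(f"expected last dim {in_features}, got {shape[-1]}")
    return shape[:-1] + (out_features,)


def flatten(shape: Shape) -> Shape:
    # nn.Flatten() defaults to start_dim=1, end_dim=-1: keep the batch dim,
    # multiply the remaining dims together.
    size = 1
    for dim in shape[1:]:
        size *= dim
    return (shape[0], size)


def classifier_output_shape(
    batch_size: int, num_frames: int, num_features: int, num_classes: int
) -> Shape:
    """Trace shapes through the classifier head (Dropout and ReLU keep shape)."""
    shape: Shape = (batch_size, num_frames, num_features)  # assumed backbone output
    shape = linear(shape, num_features, 256)  # -> (B, T, 256)
    shape = linear(shape, 256, 64)  # -> (B, T, 64)
    shape = flatten(shape)  # -> (B, T * 64)
    return linear(shape, 64 * num_frames, num_classes)  # -> (B, num_classes)


print(classifier_output_shape(batch_size=2, num_frames=16, num_features=1280, num_classes=30))  # (2, 30)
```

Because `64 * num_frames` is fixed in the last layer's weights, a model built with `num_frames=16` expects exactly 16 frames per video.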
lr = lr
instance-attribute
model_class = type(self).__name__
instance-attribute
num_classes = len(species)
instance-attribute
scheduler = torch.optim.lr_scheduler.__dict__[scheduler]
instance-attribute
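`scheduler` is resolved by looking up a class name (a string such as `"MultiStepLR"`) in the namespace of `torch.optim.lr_scheduler`. The same name-to-class resolution can be sketched with a standard-library module so it runs without PyTorch; the `resolve_class` helper and its error messages are hypothetical additions, and `collections` merely stands in for `torch.optim.lr_scheduler`.

```python
import collections
import types
from typing import Any, Dict


def resolve_class(module: types.ModuleType, name: str) -> type:
    """Hypothetical helper: look up `name` in `module`'s namespace, in the same
    style as torch.optim.lr_scheduler.__dict__[scheduler]."""
    try:
        obj = module.__dict__[name]
    except KeyError:
        # A bare KeyError would only show the name; add the module for context.
        raise ValueError(f"{module.__name__} has no attribute {name!r}") from None
    if not isinstance(obj, type):
        raise TypeError(f"{name!r} is not a class")
    return obj


# Stand-in for resolving scheduler="MultiStepLR" and passing scheduler_params.
cls = resolve_class(collections, "Counter")
params: Dict[str, Any] = {}
instance = cls("abracadabra", **params)
print(cls.__name__, instance.most_common(1))  # Counter [('a', 5)]
```

With PyTorch installed, the equivalent call would be `resolve_class(torch.optim.lr_scheduler, "MultiStepLR")`. Wrapping the lookup turns an unknown scheduler name into a clearer error than a bare `KeyError`.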
scheduler_params = scheduler_params
instance-attribute
species = species
instance-attribute
test_step_outputs = []
instance-attribute
training_step_outputs = []
instance-attribute
validation_step_outputs = []
instance-attribute
Functions
__init__(num_frames = 16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs)
Source code in /home/runner/work/zamba/zamba/zamba/models/efficientnet_models.py
aggregate_step_outputs(outputs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]
staticmethod
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
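Judging from its return type and the `y_pred`/`y_true` keys mentioned under `on_validation_epoch_end`, `aggregate_step_outputs` gathers per-step results into three epoch-level arrays. A minimal pure-Python sketch, assuming the input is a sequence of step-output dicts with `y_true`, `y_pred`, and `y_proba` keys, each holding one row per sample (one column per species is an assumption); lists stand in for `np.ndarray`, and row concatenation stands in for stacking.

```python
from typing import Dict, List, Sequence, Tuple

Rows = List[List[float]]


def aggregate_step_outputs(
    outputs: Sequence[Dict[str, Rows]],
) -> Tuple[Rows, Rows, Rows]:
    """Sketch: concatenate per-step rows into epoch-level y_true, y_pred, y_proba.

    Assumes every step output has the keys "y_true", "y_pred", "y_proba",
    each holding a list of per-sample rows.
    """
    y_true: Rows = []
    y_pred: Rows = []
    y_proba: Rows = []
    for step in outputs:
        y_true.extend(step["y_true"])
        y_pred.extend(step["y_pred"])
        y_proba.extend(step["y_proba"])
    return y_true, y_pred, y_proba


steps = [
    {"y_true": [[1, 0]], "y_pred": [[1, 0]], "y_proba": [[0.9, 0.2]]},
    {"y_true": [[0, 1], [1, 0]], "y_pred": [[0, 1], [0, 1]], "y_proba": [[0.1, 0.8], [0.4, 0.6]]},
]
y_true, y_pred, y_proba = aggregate_step_outputs(steps)
print(len(y_true), y_pred)  # 3 [[1, 0], [0, 1], [0, 1]]
```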
compute_and_log_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray, subset: str)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
configure_optimizers()
Set up the Adam optimizer. Note that this function can also return a learning rate (lr) scheduler, which is usually useful for training video models.
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
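Which scheduler is returned depends on the `scheduler` and `scheduler_params` attributes. As one illustration, PyTorch's `MultiStepLR` multiplies the learning rate by `gamma` each time the epoch count reaches one of its `milestones`. The pure-Python sketch below computes that schedule in closed form; the base rate, milestones, and `gamma` are illustrative assumptions, not zamba defaults.

```python
from bisect import bisect_right
from typing import List, Sequence


def multistep_lr(base_lr: float, milestones: Sequence[int], gamma: float, epoch: int) -> float:
    """Closed form of a MultiStepLR-style schedule.

    The rate is decayed by `gamma` once for every milestone <= epoch.
    `milestones` must be sorted in increasing order.
    """
    num_decays = bisect_right(list(milestones), epoch)
    return base_lr * gamma ** num_decays


def schedule(base_lr: float, milestones: Sequence[int], gamma: float, epochs: int) -> List[float]:
    # Learning rate used in each epoch 0 .. epochs - 1.
    return [multistep_lr(base_lr, milestones, gamma, e) for e in range(epochs)]


print(schedule(base_lr=1e-3, milestones=[3, 6], gamma=0.5, epochs=8))
```

Here the rate stays at `1e-3` for epochs 0 to 2, halves at epoch 3, and halves again at epoch 6.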
forward(x)
Source code in /home/runner/work/zamba/zamba/zamba/models/efficientnet_models.py
from_disk(path: os.PathLike, **kwargs)
classmethod
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
on_test_epoch_end()
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
on_train_start()
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
on_validation_epoch_end()
Aggregates validation_step outputs to compute and log the validation macro F1 and top K metrics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
outputs | List[dict] | List of output dictionaries from each validation step, containing y_pred and y_true. | required |
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
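The macro F1 and top K metrics named above can be sketched in pure Python. This illustrates the metric definitions only and is not zamba's `compute_and_log_metrics`: it assumes single-label integer class indices for `y_true` and `y_pred` and one probability row per sample for `y_proba`; the function names are hypothetical, and the real method may compute these differently (for example, with a metrics library).

```python
from typing import List, Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1 scores.

    A class that is neither present nor predicted contributes 0.
    """
    scores: List[float] = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


def top_k_accuracy(y_true: Sequence[int], y_proba: Sequence[Sequence[float]], k: int) -> float:
    """Fraction of samples whose true class is among the k highest probabilities."""
    hits = 0
    for label, probs in zip(y_true, y_proba):
        top_k = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(y_true)


y_true = [0, 1, 2, 2]
y_pred = [0, 2, 2, 2]
y_proba = [[0.7, 0.2, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.3, 0.1, 0.6]]
print(round(macro_f1(y_true, y_pred, 3), 3), top_k_accuracy(y_true, y_proba, 1), top_k_accuracy(y_true, y_proba, 2))  # 0.6 0.75 1.0
```

Macro averaging weights every class equally, so a rare species that is never predicted (class 1 above) pulls the score down as much as a common one.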
predict_step(batch, batch_idx, dataloader_idx: Optional[int] = None)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
test_step(batch, batch_idx)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
to_disk(path: os.PathLike)
Save out model weights to a checkpoint file on disk.
Note: this does not include callbacks, optimizer_states, or lr_schedulers. To include those, use Trainer.save_checkpoint() instead.
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
training_step(batch, batch_idx)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py
validation_step(batch, batch_idx)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch_lightning/utils.py