zamba.models.slowfast_models

Classes

SlowFast

Bases: ZambaVideoClassificationLightningModule
Pretrained SlowFast model for fine-tuning with the following architecture:
Input -> SlowFast Base (including trainable Backbone) -> Res Basic Head -> Output
Attributes:

Name | Type | Description
---|---|---
backbone | torch.nn.Module | When scheduling the backbone to train with a backbone-finetuning callback, this is the portion of the model that is unfrozen.
base | torch.nn.Module | The entire model prior to the head.
head | torch.nn.Module | The trainable head.
_backbone_output_dim | int | Dimensionality of the backbone output (and head input).
Source code in zamba/models/slowfast_models.py, lines 11-116
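To make the data flow concrete, here is a minimal, hypothetical usage sketch. The species list, clip shape, and frame count are illustrative placeholders, not values prescribed by zamba:

```python
import torch

from zamba.models.slowfast_models import SlowFast

# Hypothetical species list; num_classes is derived from its length.
model = SlowFast(species=["elephant", "leopard", "blank"])
model.eval()

# Illustrative input shape: (batch, channels, frames, height, width).
# The frame count and resolution zamba actually feeds the model may differ.
clip = torch.randn(1, 3, 32, 224, 224)

with torch.no_grad():
    logits = model(clip)  # base (including backbone) -> head -> class logits

print(logits.shape)  # expected: (1, len(species))
```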
Attributes

backbone = model.backbone  (instance attribute)
backbone_mode = backbone_mode  (instance attribute)
base = model.base  (instance attribute)
head = head  (instance attribute)
lr = lr  (instance attribute)
model_class = type(self).__name__  (instance attribute)
num_classes = len(species)  (instance attribute)
scheduler = torch.optim.lr_scheduler.__dict__[scheduler]  (instance attribute)
scheduler_params = scheduler_params  (instance attribute)
species = species  (instance attribute)
test_step_outputs = []  (instance attribute)
training_step_outputs = []  (instance attribute)
validation_step_outputs = []  (instance attribute)
Functions

__init__(backbone_mode: str = 'train', post_backbone_dropout: Optional[float] = None, output_with_global_average: bool = True, head_dropout_rate: Optional[float] = None, head_hidden_layer_sizes: Optional[Tuple[int]] = None, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs)
Initializes the SlowFast model.
Parameters:

Name | Type | Description | Default
---|---|---|---
backbone_mode | str | If "eval", treat the backbone as a feature extractor and set it to evaluation mode in all forward passes. | 'train'
post_backbone_dropout | float, optional | Dropout that operates on the output of the backbone + pool (before the fully-connected layer in the head). | None
output_with_global_average | bool | If True, apply an adaptive average pooling operation after the fully-connected layer in the head. | True
head_dropout_rate | float, optional | Dropout rate applied after the backbone and between projection layers in the head. | None
head_hidden_layer_sizes | tuple of int, optional | If not None, the sizes of the hidden layers in the head multilayer perceptron. | None
finetune_from | pathlike or str, optional | If not None, load an existing model from this path and resume training from its weights. | None
Source code in zamba/models/slowfast_models.py, lines 27-92
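A hedged constructor sketch tying the parameters above together; the species list and checkpoint path are placeholders:

```python
from zamba.models.slowfast_models import SlowFast

model = SlowFast(
    species=["elephant", "leopard", "blank"],  # inherited base-class argument
    backbone_mode="eval",            # use the backbone as a frozen feature extractor
    post_backbone_dropout=0.2,       # dropout between backbone + pool and the head
    head_hidden_layer_sizes=(256,),  # one hidden layer in the head MLP
    head_dropout_rate=0.1,           # dropout between head projection layers
    finetune_from=None,              # or a path like "checkpoints/slowfast.ckpt"
)
```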
aggregate_step_outputs(outputs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]  (staticmethod)
Source code in zamba/pytorch_lightning/utils.py, lines 206-214
compute_and_log_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray, subset: str)
Source code in zamba/pytorch_lightning/utils.py, lines 216-245
configure_optimizers()

Sets up the Adam optimizer. Note that this function can also return a learning rate scheduler, which is often useful when training video models.
Source code in zamba/pytorch_lightning/utils.py, lines 273-288
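As a sketch of the Lightning convention this method follows, assuming the lr, scheduler, and scheduler_params attributes listed above (not zamba's exact code):

```python
import torch

def configure_optimizers(self):
    # Only optimize parameters that are not frozen.
    optimizer = torch.optim.Adam(
        (p for p in self.parameters() if p.requires_grad), lr=self.lr
    )
    if self.scheduler is None:
        return optimizer
    # Lightning accepts a dict pairing the optimizer with an LR scheduler.
    return {
        "optimizer": optimizer,
        "lr_scheduler": self.scheduler(optimizer, **(self.scheduler_params or {})),
    }
```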
forward(x, *args, **kwargs)

Source code in zamba/models/slowfast_models.py, lines 111-116
from_disk(path: os.PathLike, **kwargs)  (classmethod)

Source code in zamba/pytorch_lightning/utils.py, lines 305-308
initialize_from_torchub()

Loads the SlowFast model from Torch Hub and prepares the ZambaVideoClassificationLightningModule by removing the head and setting the backbone and base.
Source code in zamba/models/slowfast_models.py, lines 94-109
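For reference, PyTorchVideo exposes pretrained SlowFast networks through Torch Hub roughly as below. The exact entrypoint zamba uses is an assumption; zamba's method additionally removes the classification head and wires up backbone and base:

```python
import torch

# Load a pretrained SlowFast network from the PyTorchVideo hub.
# "slowfast_r50" is a published PyTorchVideo entrypoint; whether it is the
# one zamba loads is an assumption.
base = torch.hub.load(
    "facebookresearch/pytorchvideo", "slowfast_r50", pretrained=True
)
```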
on_test_epoch_end()

Source code in zamba/pytorch_lightning/utils.py, lines 259-262
on_train_start()

Source code in zamba/pytorch_lightning/utils.py, lines 162-173
on_validation_epoch_end()

Aggregates validation_step outputs to compute and log the validation macro F1 and top-K metrics.
Parameters:

Name | Type | Description | Default
---|---|---|---
outputs | List[dict] | List of output dictionaries from each validation step, containing y_pred and y_true. | required
Source code in zamba/pytorch_lightning/utils.py, lines 247-257
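The logged macro F1 can be thought of as the scikit-learn computation below, applied to the aggregated label and prediction arrays (an illustrative sketch with stand-in data, not zamba's exact code):

```python
import numpy as np
from sklearn.metrics import f1_score

# Stand-ins for the aggregated outputs from aggregate_step_outputs.
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
y_pred = np.array([[1, 0, 0], [0, 0, 1], [0, 0, 1]])

# Macro F1 averages the per-class F1 scores, weighting each class equally.
macro_f1 = f1_score(y_true, y_pred, average="macro")
print(macro_f1)
```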
predict_step(batch, batch_idx, dataloader_idx: Optional[int] = None)

Source code in zamba/pytorch_lightning/utils.py, lines 264-268
test_step(batch, batch_idx)

Source code in zamba/pytorch_lightning/utils.py, lines 201-204
to_disk(path: os.PathLike)

Save out model weights to a checkpoint file on disk.
Note: this does not include callbacks, optimizer_states, or lr_schedulers. To include those, use Trainer.save_checkpoint() instead.
Source code in zamba/pytorch_lightning/utils.py, lines 290-303
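A round-trip sketch using to_disk together with from_disk, continuing from a constructed model as in the __init__ example above (the path is a placeholder):

```python
# Save only the model weights; optimizer state and callbacks are not included.
model.to_disk("slowfast_weights.ckpt")

# Restore the weights later via the classmethod.
restored = SlowFast.from_disk("slowfast_weights.ckpt")
```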
training_step(batch, batch_idx)

Source code in zamba/pytorch_lightning/utils.py, lines 175-181
validation_step(batch, batch_idx)

Source code in zamba/pytorch_lightning/utils.py, lines 196-199