zamba.pytorch.transforms

Attributes

imagenet_normalization_values = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) module-attribute

Classes

ConvertHWCtoCHW

Bases: torch.nn.Module

Convert tensor from (0:H, 1:W, 2:C) to (2:C, 0:H, 1:W)

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class ConvertHWCtoCHW(torch.nn.Module):
    """Convert tensor from (0:H, 1:W, 2:C) to (2:C, 0:H, 1:W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(2, 0, 1)

Functions

forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    return vid.permute(2, 0, 1)
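
A quick shape check; the single RGB frame below is an assumed example input, not part of the module:

import torch
from zamba.pytorch.transforms import ConvertHWCtoCHW

frame = torch.zeros(224, 224, 3)  # (H, W, C)
chw = ConvertHWCtoCHW()(frame)
# chw.shape == torch.Size([3, 224, 224]), i.e. (C, H, W)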

ConvertTCHWtoCTHW

Bases: torch.nn.Module

Convert tensor from (T, C, H, W) to (C, T, H, W)

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class ConvertTCHWtoCTHW(torch.nn.Module):
    """Convert tensor from (T, C, H, W) to (C, T, H, W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(1, 0, 2, 3)

Functions

forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    return vid.permute(1, 0, 2, 3)
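
For example, with an assumed 16-frame clip in (T, C, H, W) layout:

import torch
from zamba.pytorch.transforms import ConvertTCHWtoCTHW

clip = torch.zeros(16, 3, 224, 224)  # (T, C, H, W)
cthw = ConvertTCHWtoCTHW()(clip)
# cthw.shape == torch.Size([3, 16, 224, 224]), i.e. (C, T, H, W)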

ConvertTHWCtoCTHW

Bases: torch.nn.Module

Convert tensor from (0:T, 1:H, 2:W, 3:C) to (3:C, 0:T, 1:H, 2:W)

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class ConvertTHWCtoCTHW(torch.nn.Module):
    """Convert tensor from (0:T, 1:H, 2:W, 3:C) to (3:C, 0:T, 1:H, 2:W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(3, 0, 1, 2)

Functions

forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    return vid.permute(3, 0, 1, 2)
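
This is the conversion needed when decoded frames arrive time-first with channels last; a minimal sketch with an assumed clip shape:

import torch
from zamba.pytorch.transforms import ConvertTHWCtoCTHW

vid = torch.zeros(16, 224, 224, 3)  # (T, H, W, C)
cthw = ConvertTHWCtoCTHW()(vid)
# cthw.shape == torch.Size([3, 16, 224, 224]), i.e. (C, T, H, W)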

ConvertTHWCtoTCHW

Bases: torch.nn.Module

Convert tensor from (T, H, W, C) to (T, C, H, W)

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class ConvertTHWCtoTCHW(torch.nn.Module):
    """Convert tensor from (T, H, W, C) to (T, C, H, W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(0, 3, 1, 2)

Functions

forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    return vid.permute(0, 3, 1, 2)
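
Sketch with the same assumed (T, H, W, C) clip shape, this time keeping the time axis first:

import torch
from zamba.pytorch.transforms import ConvertTHWCtoTCHW

vid = torch.zeros(16, 224, 224, 3)  # (T, H, W, C)
tchw = ConvertTHWCtoTCHW()(vid)
# tchw.shape == torch.Size([16, 3, 224, 224]), i.e. (T, C, H, W)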

PackSlowFastPathways

Bases: torch.nn.Module

Creates the slow and fast pathway inputs for the slowfast model.

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class PackSlowFastPathways(torch.nn.Module):
    """Creates the slow and fast pathway inputs for the slowfast model."""

    def __init__(self, alpha: int = 4):
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        fast_pathway = frames
        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(0, frames.shape[1] - 1, frames.shape[1] // self.alpha).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list

Attributes

alpha = alpha instance-attribute

Functions

__init__(alpha: int = 4)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def __init__(self, alpha: int = 4):
    super().__init__()
    self.alpha = alpha
forward(frames: torch.Tensor)
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, frames: torch.Tensor):
    fast_pathway = frames
    # Perform temporal sampling from the fast pathway.
    slow_pathway = torch.index_select(
        frames,
        1,
        torch.linspace(0, frames.shape[1] - 1, frames.shape[1] // self.alpha).long(),
    )
    frame_list = [slow_pathway, fast_pathway]
    return frame_list
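
With the default alpha=4, the slow pathway keeps roughly every fourth frame while the fast pathway keeps all of them. A sketch assuming a (C, T, H, W) input with 32 frames, which is the layout this transform receives inside slowfast_transforms:

import torch
from zamba.pytorch.transforms import PackSlowFastPathways

frames = torch.zeros(3, 32, 224, 224)  # (C, T, H, W)
slow, fast = PackSlowFastPathways(alpha=4)(frames)
# slow.shape == torch.Size([3, 8, 224, 224])   -- 32 // 4 temporally subsampled frames
# fast.shape == torch.Size([3, 32, 224, 224])  -- all frames, unchanged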

PadDimensions

Bases: torch.nn.Module

Pads a tensor to ensure a fixed output dimension for a given axis.

Attributes:

Name Type Description
dimension_sizes

A tuple of int or None with the same length as the number of dimensions of the input tensor. If int, pad that dimension to at least that size. If None, do not pad.

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class PadDimensions(torch.nn.Module):
    """Pads a tensor to ensure a fixed output dimension for a given axis.

    Attributes:
        dimension_sizes: A tuple of int or None with the same length as the number of dimensions
            of the input tensor. If int, pad that dimension to at least that size. If None, do not pad.
    """

    def __init__(self, dimension_sizes: Tuple[Optional[int]]):
        super().__init__()
        self.dimension_sizes = dimension_sizes

    @staticmethod
    def compute_left_and_right_pad(original_size: int, padded_size: int) -> Tuple[int, int]:
        """Computes left and right pad size.

        Args:
            original_size (int): The original tensor size
            padded_size (int): The desired tensor size

        Returns:
            Tuple[int, int]: Pad size for left and right. For an odd padding size, right = left + 1.
        """
        if original_size >= padded_size:
            return 0, 0
        pad = padded_size - original_size
        quotient, remainder = divmod(pad, 2)
        return quotient, quotient + remainder

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        padding = tuple(
            itertools.chain.from_iterable(
                (0, 0)
                if padded_size is None
                else self.compute_left_and_right_pad(original_size, padded_size)
                for original_size, padded_size in zip(vid.shape, self.dimension_sizes)
            )
        )
        return torch.nn.functional.pad(vid, padding[::-1])

Attributes

dimension_sizes = dimension_sizes instance-attribute

Functions

__init__(dimension_sizes: Tuple[Optional[int]])
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def __init__(self, dimension_sizes: Tuple[Optional[int]]):
    super().__init__()
    self.dimension_sizes = dimension_sizes
compute_left_and_right_pad(original_size: int, padded_size: int) -> Tuple[int, int] staticmethod

Computes left and right pad size.

Parameters:

Name Type Description Default
original_size (int)

The original tensor size

required
padded_size (int)

The desired tensor size

required

Returns:

Type Description
Tuple[int, int]

Tuple[int, int]: Pad size for left and right. For an odd padding size, right = left + 1.

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
@staticmethod
def compute_left_and_right_pad(original_size: int, padded_size: int) -> Tuple[int, int]:
    """Computes left and right pad size.

    Args:
        original_size (int): The original tensor size
        padded_size (int): The desired tensor size

    Returns:
        Tuple[int, int]: Pad size for left and right. For an odd padding size, right = left + 1.
    """
    if original_size >= padded_size:
        return 0, 0
    pad = padded_size - original_size
    quotient, remainder = divmod(pad, 2)
    return quotient, quotient + remainder
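
Worked example: padding 27 frames up to 32 requires 5 extra frames, split as 2 on one side and 3 on the other.

from zamba.pytorch.transforms import PadDimensions

PadDimensions.compute_left_and_right_pad(27, 32)  # (2, 3)
PadDimensions.compute_left_and_right_pad(32, 32)  # (0, 0) -- already large enough, no padding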
forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    padding = tuple(
        itertools.chain.from_iterable(
            (0, 0)
            if padded_size is None
            else self.compute_left_and_right_pad(original_size, padded_size)
            for original_size, padded_size in zip(vid.shape, self.dimension_sizes)
        )
    )
    return torch.nn.functional.pad(vid, padding[::-1])
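
For instance, padding only the time axis of an assumed (C, T, H, W) clip so it has at least 32 frames:

import torch
from zamba.pytorch.transforms import PadDimensions

vid = torch.zeros(3, 27, 224, 224)  # (C, T, H, W), only 27 frames
padded = PadDimensions((None, 32, None, None))(vid)
# padded.shape == torch.Size([3, 32, 224, 224])  -- zero-padded along T only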

Uint8ToFloat

Bases: torch.nn.Module

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class Uint8ToFloat(torch.nn.Module):
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor / 255.0

Functions

forward(tensor: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
    return tensor / 255.0
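
Uint8ToFloat rescales uint8 pixel values from [0, 255] to floats in [0.0, 1.0]; a minimal sketch with an assumed frame stack:

import torch
from zamba.pytorch.transforms import Uint8ToFloat

frames = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)
scaled = Uint8ToFloat()(frames)
# scaled is a float tensor with values in [0.0, 1.0]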

VideotoImg

Bases: torch.nn.Module

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
class VideotoImg(torch.nn.Module):
    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.squeeze(0)

Functions

forward(vid: torch.Tensor) -> torch.Tensor
Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def forward(self, vid: torch.Tensor) -> torch.Tensor:
    return vid.squeeze(0)
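
VideotoImg squeezes out a leading singleton time dimension, turning a one-frame video into a single image. For example:

import torch
from zamba.pytorch.transforms import VideotoImg

clip = torch.zeros(1, 3, 224, 224)  # (T=1, C, H, W)
img = VideotoImg()(clip)
# img.shape == torch.Size([3, 224, 224])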

Functions

slowfast_transforms()

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def slowfast_transforms():
    return transforms.Compose(
        [
            ConvertTHWCtoTCHW(),
            Uint8ToFloat(),
            Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
            ConvertTCHWtoCTHW(),
            PadDimensions((None, 32, None, None)),
            PackSlowFastPathways(),
        ]
    )
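
End-to-end sketch on an assumed uint8 (T, H, W, C) video: the composed pipeline normalizes, pads the clip to 32 frames, and returns the [slow, fast] pathway list expected by the slowfast model.

import torch
from zamba.pytorch.transforms import slowfast_transforms

video = torch.randint(0, 256, (16, 256, 256, 3), dtype=torch.uint8)  # (T, H, W, C)
slow, fast = slowfast_transforms()(video)
# slow.shape == torch.Size([3, 8, 256, 256])
# fast.shape == torch.Size([3, 32, 256, 256])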

zamba_image_model_transforms(single_frame = False, normalization_values = imagenet_normalization_values, channels_first = False)

Source code in /home/runner/work/zamba/zamba/zamba/pytorch/transforms.py
def zamba_image_model_transforms(
    single_frame=False, normalization_values=imagenet_normalization_values, channels_first=False
):
    img_transforms = [
        ConvertTHWCtoTCHW(),
        Uint8ToFloat(),
        transforms.Normalize(**normalization_values),
    ]

    if single_frame:
        img_transforms += [VideotoImg()]  # squeeze dim

    if channels_first:
        img_transforms += [ConvertTCHWtoCTHW()]

    return transforms.Compose(img_transforms)
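
Usage sketch on an assumed uint8 (T, H, W, C) video; with the defaults, the output stays time-first as (T, C, H, W) with ImageNet normalization applied:

import torch
from zamba.pytorch.transforms import zamba_image_model_transforms

video = torch.randint(0, 256, (16, 240, 426, 3), dtype=torch.uint8)  # (T, H, W, C)
batch = zamba_image_model_transforms()(video)
# batch.shape == torch.Size([16, 3, 240, 426]), i.e. (T, C, H, W), normalized floats

Passing single_frame=True additionally squeezes a one-frame clip down to (C, H, W), and channels_first=True permutes the output to (C, T, H, W).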