|
1 | 1 | import json
|
2 | 2 | import re
|
| 3 | +from abc import ABC |
3 | 4 | from typing import List, Tuple, Optional
|
4 | 5 |
|
5 | 6 | import numpy as np
|
6 | 7 | import torch
|
7 |
| -import torchvision |
8 | 8 | import torchaudio
|
| 9 | +import torchvision |
9 | 10 | from einops import rearrange
|
10 | 11 | from pytorch_lightning import Callback, Trainer, LightningModule
|
11 | 12 | from torch import Tensor
|
@@ -84,33 +85,37 @@ def resize_video(tensor: Tensor, size: Tuple[int, int], resize_method: str = "bi
|
84 | 85 | return F.interpolate(tensor, size=size, mode=resize_method)
|
85 | 86 |
|
86 | 87 |
|
87 |
| -def _get_Conv(PtConv): |
88 |
| - class _ConvNd(Module): |
| 88 | +class _ConvNd(Module, ABC): |
| 89 | + |
| 90 | + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, |
| 91 | + build_activation: Optional[callable] = None |
| 92 | + ): |
| 93 | + super().__init__() |
| 94 | + self.conv = self.PtConv( |
| 95 | + in_channels, out_channels, kernel_size, stride=stride, padding=padding |
| 96 | + ) |
| 97 | + if build_activation is not None: |
| 98 | + self.activation = build_activation() |
| 99 | + else: |
| 100 | + self.activation = None |
| 101 | + |
| 102 | + def forward(self, x: Tensor) -> Tensor: |
| 103 | + x = self.conv(x) |
| 104 | + if self.activation is not None: |
| 105 | + x = self.activation(x) |
| 106 | + return x |
| 107 | + |
89 | 108 |
|
90 |
| - def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, |
91 |
| - build_activation: Optional[callable] = None |
92 |
| - ): |
93 |
| - super().__init__() |
94 |
| - self.conv = PtConv( |
95 |
| - in_channels, out_channels, kernel_size, stride=stride, padding=padding |
96 |
| - ) |
97 |
| - if build_activation is not None: |
98 |
| - self.activation = build_activation() |
99 |
| - else: |
100 |
| - self.activation = None |
| 109 | +class Conv1d(_ConvNd): |
| 110 | + PtConv = torch.nn.Conv1d |
101 | 111 |
|
102 |
| - def forward(self, x: Tensor) -> Tensor: |
103 |
| - x = self.conv(x) |
104 |
| - if self.activation is not None: |
105 |
| - x = self.activation(x) |
106 |
| - return x |
107 | 112 |
|
108 |
| - return _ConvNd |
| 113 | +class Conv2d(_ConvNd): |
| 114 | + PtConv = torch.nn.Conv2d |
109 | 115 |
|
110 | 116 |
|
111 |
| -Conv1d = _get_Conv(torch.nn.Conv1d) |
112 |
| -Conv2d = _get_Conv(torch.nn.Conv2d) |
113 |
| -Conv3d = _get_Conv(torch.nn.Conv3d) |
| 117 | +class Conv3d(_ConvNd): |
| 118 | + PtConv = torch.nn.Conv3d |
114 | 119 |
|
115 | 120 |
|
116 | 121 | def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
|
|
0 commit comments