Skip to content

Commit c4b4d32

Browse files
committed
fix pickle problem
1 parent 2892df2 commit c4b4d32

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

utils.py

+28-23
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import re
3+
from abc import ABC
34
from typing import List, Tuple, Optional
45

56
import numpy as np
67
import torch
7-
import torchvision
88
import torchaudio
9+
import torchvision
910
from einops import rearrange
1011
from pytorch_lightning import Callback, Trainer, LightningModule
1112
from torch import Tensor
@@ -84,33 +85,37 @@ def resize_video(tensor: Tensor, size: Tuple[int, int], resize_method: str = "bi
8485
return F.interpolate(tensor, size=size, mode=resize_method)
8586

8687

87-
def _get_Conv(PtConv):
88-
class _ConvNd(Module):
88+
class _ConvNd(Module, ABC):
89+
90+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0,
91+
build_activation: Optional[callable] = None
92+
):
93+
super().__init__()
94+
self.conv = self.PtConv(
95+
in_channels, out_channels, kernel_size, stride=stride, padding=padding
96+
)
97+
if build_activation is not None:
98+
self.activation = build_activation()
99+
else:
100+
self.activation = None
101+
102+
def forward(self, x: Tensor) -> Tensor:
103+
x = self.conv(x)
104+
if self.activation is not None:
105+
x = self.activation(x)
106+
return x
107+
89108

90-
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0,
91-
build_activation: Optional[callable] = None
92-
):
93-
super().__init__()
94-
self.conv = PtConv(
95-
in_channels, out_channels, kernel_size, stride=stride, padding=padding
96-
)
97-
if build_activation is not None:
98-
self.activation = build_activation()
99-
else:
100-
self.activation = None
109+
class Conv1d(_ConvNd):
110+
PtConv = torch.nn.Conv1d
101111

102-
def forward(self, x: Tensor) -> Tensor:
103-
x = self.conv(x)
104-
if self.activation is not None:
105-
x = self.activation(x)
106-
return x
107112

108-
return _ConvNd
113+
class Conv2d(_ConvNd):
114+
PtConv = torch.nn.Conv2d
109115

110116

111-
Conv1d = _get_Conv(torch.nn.Conv1d)
112-
Conv2d = _get_Conv(torch.nn.Conv2d)
113-
Conv3d = _get_Conv(torch.nn.Conv3d)
117+
class Conv3d(_ConvNd):
118+
PtConv = torch.nn.Conv3d
114119

115120

116121
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):

0 commit comments

Comments
 (0)