-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
47 lines (40 loc) · 1.52 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from torch import nn
class TuduiModel(nn.Module):
    """Small CNN that predicts class logits and box coordinates per image.

    Three conv/ReLU/max-pool stages feed a two-layer fully connected head
    that emits ``num_pred_bbox * (num_classes + 4)`` values per image.  The
    output is laid out box by box: for every predicted box, ``num_classes``
    raw class logits followed by 4 box values squashed into [0, 1] by a
    sigmoid.

    Args:
        num_classes: Number of class logits emitted per predicted box.
        num_pred_bbox: Number of boxes predicted per image.
        input_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair.  The default
            of 416 yields the 52x52 feature map (64 * 52 * 52 flattened
            features) that this model was originally hard-wired for.

    Raises:
        ValueError: If ``input_size`` is too small to survive the three
            2x2 poolings.
    """

    # Channels produced by the last convolution.
    _CONV_OUT_CHANNELS = 64
    # Three stride-2 poolings shrink each spatial dimension by 2 ** 3.
    _DOWNSAMPLE_FACTOR = 8
    # Values describing one box.
    _BOX_VALUES = 4

    def __init__(self, num_classes: int = 20, num_pred_bbox: int = 1,
                 input_size=416):
        super().__init__()
        self.num_classes = num_classes
        self.num_pred_bbox = num_pred_bbox
        flat_features = self._flattened_features(input_size)

        # Module names and ordering are kept stable so existing
        # state_dict checkpoints keep loading.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_pred_bbox * (num_classes + self._BOX_VALUES)),
        )

    @staticmethod
    def _flattened_features(input_size) -> int:
        """Return the flattened conv output width for ``input_size``."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # The 3x3/padding-1 convolutions keep the spatial size; each 2x2
        # pooling floor-divides it by 2, so three of them floor-divide by 8.
        feat_h = height // TuduiModel._DOWNSAMPLE_FACTOR
        feat_w = width // TuduiModel._DOWNSAMPLE_FACTOR
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input_size %r is too small: each side must be at least %d"
                % (input_size, TuduiModel._DOWNSAMPLE_FACTOR)
            )
        return TuduiModel._CONV_OUT_CHANNELS * feat_h * feat_w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Batch of 3-channel images whose spatial size matches the
                ``input_size`` the model was built with.

        Returns:
            Tensor with ``num_pred_bbox * (num_classes + 4)`` values per
            batch item: for each box, raw class logits followed by 4
            sigmoid-normalised box values.
        """
        x = self.conv_layers(x)
        x = self.fc_layers(x)

        # View the head output as one row per predicted box so the
        # class/box split is applied to every box, not only the first one.
        per_box = x.reshape(
            x.size(0), self.num_pred_bbox, self.num_classes + self._BOX_VALUES
        )
        # Separate class predictions from bounding-box predictions.
        class_pred = per_box[..., :self.num_classes]
        # Box predictions are normalised to the 0-1 range.
        box_pred = torch.sigmoid(per_box[..., self.num_classes:])

        # Recombine and flatten back to one vector per image.
        output = torch.cat([class_pred, box_pred], dim=-1)
        return output.flatten(start_dim=1)
if __name__ == '__main__':
    # Smoke test: push one random 3-channel image through a 16-class model.
    # 418 pixels per side pools down to 52 (418 -> 209 -> 104 -> 52).
    # The name avoids shadowing the builtin input().
    sample_input = torch.randn(1, 3, 418, 418)
    model = TuduiModel(16)
    output = model(sample_input)
    print("Output shape:", output.shape)
    print("Sample output:", output[0])