【 ? 】网络迷途
其实是某比赛的原题,被我搬过来了,这个 wp 也是(
目标是从给出的 output 反推输入。
观察网络结构,前三层为 Conv2d、Conv2d、MaxPool2d。由于 kernel 很小,图像经过三层处理,应 该仍然能肉眼看出 flag。
因此考虑如何恢复出 Linear 层之前的图片。显然,要完全恢复 10 个 channel 是不可行的,因为在 Linear 层之后,经历了 1x1 卷积,将 10 个 channel 合并成了一个 channel。
注意到 1x1 卷积可以认为是各 channel 的加权平均。所以,我们不再尝试恢复出 10 个 channel,而是 去恢复这 10 个 channel 的「均值」。
因此,解题过程为:
- 通过 sigmoid 的反函数,恢复 1x1 卷积层之后的图像
- 减去 1x1 卷积的 bias、除以 1x1 卷积的 weight 均值,获得 Linear 层之后的「均值输出」
- 减去 Linear 层的 bias、乘以 Linear 层的 weight 矩阵之伪逆矩阵,获得「均值图像」
- 从图像中观察出 flag,其实后来发现后面有几个字看不清,不过貌似没人到这一步,就算了。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
net = torch.load('net.pt')
y = np.array(Image.open('enc.png').convert('L')) / 255
y = torch.Tensor(y).reshape([1, 285, 2850])
y = -torch.log((1 / y) - 1)
w = net[-3].weight.detach()
x = (((y - net[-2].bias.detach()) / net[-2].weight.detach().sum()
-net[-3].bias.detach()) @ w.T.pinverse())
plt.imshow(x[0, :, :], cmap='gray')
plt.show()