计算卷积带宽
假设输入数据 shape 为 \((b, c_i, h_i, w_i)\),卷积核 shape 为 \((c_o, c_i, k_h, k_w)\),padding 为 \((p_0, p_1, p_2, p_3)\),strides 为 \((s_h, s_w)\)。下面的公式假设对称填充,即上下两侧均填充 \(p_0\)、左右两侧均填充 \(p_1\),则有:
\[\begin{split}
\begin{aligned}
h_o &= \left\lfloor \dfrac{h_i + 2 p_0 - k_h}{s_h} + 1 \right\rfloor \\
w_o &= \left\lfloor \dfrac{w_i + 2 p_1 - k_w}{s_w} + 1 \right\rfloor
\end{aligned}
\end{split}\]
备注
需要满足以下约束:
\[\begin{split}
\begin{cases}
h_i + 2 p_0 \ge k_h \\
w_i + 2 p_1 \ge k_w \\
s_h \ge 1 \\
s_w \ge 1
\end{cases}
\end{split}\]
这些条件很容易被忽略。
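假设卷积核在水平方向共有 \(t_w\) 个位置、竖直方向共有 \(t_h\) 个位置(即分别滑动 \(t_w - 1\)、\(t_h - 1\) 次),每个位置搬运一块 \(k_h k_w\) 的输入,则输入数据总计搬运:\(\text{inp\_load\_nbytes} = (t_h t_w)(k_h k_w)\),重叠区域面积为:\(\text{overlap} = (k_h k_w - s_h s_w)(t_h-1)(t_w-1)\)。有效输入(即被至少一个窗口覆盖的区域,此处设 \(s_h \le k_h\)、\(s_w \le k_w\))为:\(((t_h - 1) s_h + k_h)((t_w - 1) s_w + k_w)\)。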
from dataclasses import dataclass
@dataclass
class Conv2dThroughput:
    """Input-traffic model for a 2-D convolution window sliding over its input.

    All quantities are areas counted in input elements. No element-size
    factor is applied, so multiply by the dtype size to obtain bytes.
    """

    k_h: int  # kernel height
    k_w: int  # kernel width
    s_h: int  # vertical stride of the kernel
    s_w: int  # horizontal stride of the kernel
    t_h: int  # number of kernel placements along the vertical axis
    t_w: int  # number of kernel placements along the horizontal axis

    @property
    def inp_load_nbytes(self) -> int:
        """Total input loaded when every window is fetched independently.

        Each placement loads a ``k_h * k_w`` patch, and there are
        ``t_h * t_w`` placements.
        """
        return self.t_h * self.t_w * self.k_h * self.k_w

    @property
    def inp_overlap_nbytes(self) -> int:
        """Overlap traffic as estimated by this model.

        The overlap contributed per interior step is taken to be the fixed
        area ``k_h * k_w - s_h * s_w``, and ``(t_h - 1) * (t_w - 1)`` such
        steps are counted.

        NOTE(review): this is not in general equal to
        ``inp_load_nbytes - inp_valid_nbytes``. For example, k=3, s=2, t=3
        gives 20 here versus 81 - 49 = 32. Confirm which quantity is
        intended.
        """
        overlap_block = self.k_h * self.k_w - self.s_h * self.s_w
        time_block = (self.t_h - 1) * (self.t_w - 1)
        return overlap_block * time_block

    @property
    def inp_valid_nbytes(self) -> int:
        """Area of the input actually covered by at least one window.

        Along one axis, ``t`` windows of size ``k`` placed with stride ``s``
        cover ``(t - 1) * min(s, k) + k`` cells. Each window after the first
        adds ``s`` new cells when windows overlap (``s <= k``) and a full
        ``k`` when they are disjoint (``s > k``). The covered 2-D area is the
        product of the two per-axis extents.

        Returns 0 when there are no placements.
        """
        if self.t_h <= 0 or self.t_w <= 0:
            return 0
        block_h = (self.t_h - 1) * min(self.s_h, self.k_h) + self.k_h
        block_w = (self.t_w - 1) * min(self.s_w, self.k_w) + self.k_w
        return block_h * block_w
import plotly.graph_objects as go
fig = go.Figure()

# Axes: an 8x8 grid of input cells with the origin at the top-left, so the
# row index grows downwards. Giving the y range high-to-low already reverses
# the axis. An explicit ``autorange`` would make plotly ignore this fixed
# range, so it is deliberately not set.
fig.update_xaxes(range=[0, 8], showgrid=True, side="top")
fig.update_yaxes(range=[8, 0], showgrid=True)

# Geometry illustrated: a 3x3 kernel sliding with stride 2, with three
# placements per direction, so windows start at 0, 2 and 4 on each axis.
_kernel = 3

# Every kernel placement except the first, as (x0, y0, colour, opacity).
# Shapes are layered in the order they are added, so this order matters.
_windows = [
    (2, 0, "red", 0.75),
    (4, 0, "green", 0.5),
    (0, 2, "red", 0.5),
    (0, 4, "green", 0.5),
    (2, 2, "red", 0.5),
    (2, 4, "green", 0.5),
    (4, 2, "red", 0.5),
    (4, 4, "green", 0.5),
]
for _x0, _y0, _colour, _alpha in _windows:
    fig.add_shape(type="rect",
                  x0=_x0, y0=_y0, x1=_x0 + _kernel, y1=_y0 + _kernel,
                  line=dict(color=_colour, width=7),
                  fillcolor=_colour,
                  opacity=_alpha)

# Whole input region touched by the windows: 0..7 on both axes.
fig.add_shape(type="rect",
              x0=0, y0=0, x1=7, y1=7,
              line=dict(color="gray", width=7),
              fillcolor="wheat",
              opacity=0.2)

# First kernel placement, drawn on top of the region.
fig.add_shape(type="rect",
              x0=0, y0=0, x1=_kernel, y1=_kernel,
              line=dict(color="LightSkyBlue", width=7),
              fillcolor="LightSkyBlue",
              opacity=0.5)

# Purple guide lines at x = 3, 6 and y = 3, 6.
# NOTE(review): these sit on multiples of the kernel size rather than on
# window edges (windows end at 3, 5 and 7). Confirm what they should mark.
for _pos in (3, 6):
    fig.add_shape(type="line",
                  x0=_pos, y0=0, x1=_pos, y1=8,
                  line=dict(color="purple", width=5))
for _pos in (3, 6):
    fig.add_shape(type="line",
                  x0=0, y0=_pos, x1=8, y1=_pos,
                  line=dict(color="purple", width=5))

# Overlap strips: neighbouring windows [0, 3] / [2, 5] and [2, 5] / [4, 7]
# overlap on [2, 3] and [4, 5]. Each strip spans the full 0..7 extent of
# the other axis. Each strip is drawn once so all four render identically.
_overlap_starts = (2, 4)
for _lo in _overlap_starts:
    fig.add_shape(type="rect",
                  fillcolor="wheat",
                  opacity=0.75,
                  x0=_lo, y0=0, x1=_lo + 1, y1=7,
                  line=dict(color="yellow", width=5))
for _lo in _overlap_starts:
    fig.add_shape(type="rect",
                  fillcolor="wheat",
                  opacity=0.75,
                  x0=0, y0=_lo, x1=7, y1=_lo + 1,
                  line=dict(color="yellow", width=5))

fig.update_shapes(dict(xref='x', yref='y'))
# update_layout returns the figure. It is kept as the final expression,
# presumably so a notebook cell renders the figure as its output.
fig.update_layout(
    height=600,
    width=600,
)