计算卷积带宽

计算卷积带宽#

假设输入数据 shape: \((b, c_i, h_i, w_i)\),卷积核为 \((c_o, c_i, k_h, k_w)\),padding 为 \((p_0, p_1, p_2, p_3)\),strides 为:\((s_h, s_w)\),则有:

\[\begin{split} \begin{align} h_o = \lfloor \dfrac{h_i + 2 p_0 - k_h}{s_h} + 1 \rfloor \\ w_o = \lfloor \dfrac{w_i + 2 p_1 - k_w}{s_w} + 1 \rfloor \end{align} \end{split}\]

备注

需要考虑一种约束:

\[\begin{split} \begin{cases} h_i + 2 p_0 \ge k_h \\ w_i + 2 p_1 \ge k_w \\ s_h \ge 1 \\ s_w \ge 1 \end{cases} \end{split}\]

这个条件可能会被忽略。

假设水平方向滑动 \(t_w\) 次,竖直方向滑动 \(t_h\) 次,则输入数据总计搬运:\(\text{inp\_load\_nbytes} = (t_h t_w) (k_h k_w)\),重叠区域面积为:\(\text{overlap} = (k_h k_w - s_h s_w)((t_h-1) (t_w-1))\),有效输入:\(((t_h - 1) s_h + k_h)((t_w - 1) s_w + k_w)\)

from dataclasses import dataclass

@dataclass
class Conv2dThroughput:
    """Input-traffic estimates for a 2-D kernel window sliding over an input.

    All quantities are computed per channel from the kernel size, the stride
    and the number of window placements along each axis.

    NOTE: despite the ``_nbytes`` suffix, every property returns a count of
    elements (an area); no element size is applied anywhere in this class.
    """

    k_h: int  # kernel height
    k_w: int  # kernel width
    s_h: int  # vertical stride of the kernel
    s_w: int  # horizontal stride of the kernel
    t_h: int  # number of kernel placements along the vertical axis
    t_w: int  # number of kernel placements along the horizontal axis

    @property
    def inp_load_nbytes(self) -> int:
        """Return the total number of input elements loaded.

        Each placement loads a full ``k_h * k_w`` window, and there are
        ``t_h * t_w`` placements in total.
        """
        return self.t_h * self.t_w * self.k_h * self.k_w

    @property
    def inp_overlap_nbytes(self) -> int:
        """Return the number of input elements counted as overlapping loads.

        The overlap attributed to one step taken in both directions at once is
        ``k_h * k_w - s_h * s_w``; it is multiplied by
        ``(t_h - 1) * (t_w - 1)`` such steps.

        NOTE(review): this is not ``inp_load_nbytes - inp_valid_nbytes`` in
        general (k=3, s=2, t=3 gives 20 here versus 81 - 49 = 32) -- confirm
        which definition of "overlap" is intended.
        """
        overlap_block = self.k_h * self.k_w - self.s_h * self.s_w
        time_block = (self.t_h - 1) * (self.t_w - 1)
        return overlap_block * time_block

    @property
    def inp_valid_nbytes(self) -> int:
        """Return the area of the input region spanned by all placements.

        Along one axis the first placement spans ``k`` elements and each of
        the remaining ``t - 1`` placements extends the span by one stride, so
        the extent is ``(t - 1) * s + k``.
        """
        block_h = (self.t_h - 1) * self.s_h + self.k_h
        block_w = (self.t_w - 1) * self.s_w + self.k_w
        return block_h * block_w
import plotly.graph_objects as go

fig = go.Figure()

# Axes: x covers 0..8 with its ticks on the top edge; y is given the
# descending range [8, 0] together with autorange="reversed".
fig.update_xaxes(range=[0, 8], showgrid=True, side="top")
fig.update_yaxes(range=[8, 0], autorange="reversed", showgrid=True)

# Filled rectangles, added in the order listed.
# Columns: (x0, y0, x1, y1), outline colour, fill colour, opacity.
# The 3x3 squares sit at offsets 0, 2 and 4 along each axis; the 7x7
# square encloses all of them.
_filled_rects = [
    ((2, 0, 5, 3), "red", "red", 0.75),
    ((4, 0, 7, 3), "green", "green", 0.5),
    ((0, 2, 3, 5), "red", "red", 0.5),
    ((0, 4, 3, 7), "green", "green", 0.5),
    ((2, 2, 5, 5), "red", "red", 0.5),
    ((2, 4, 5, 7), "green", "green", 0.5),
    ((4, 2, 7, 5), "red", "red", 0.5),
    ((4, 4, 7, 7), "green", "green", 0.5),
    ((0, 0, 7, 7), "gray", "wheat", 0.2),
    ((0, 0, 3, 3), "LightSkyBlue", "LightSkyBlue", 0.5),
]
for (_x0, _y0, _x1, _y1), _edge, _fill, _alpha in _filled_rects:
    fig.add_shape(
        type="rect",
        x0=_x0, y0=_y0, x1=_x1, y1=_y1,
        line={"color": _edge, "width": 7},
        fillcolor=_fill,
        opacity=_alpha,
    )

# Purple guide lines at 3 and 6: two vertical ones, then two horizontal ones.
_guide_lines = [
    (3, 0, 3, 8),
    (6, 0, 6, 8),
    (0, 3, 8, 3),
    (0, 6, 8, 6),
]
for _x0, _y0, _x1, _y1 in _guide_lines:
    fig.add_shape(
        type="line",
        x0=_x0, y0=_y0, x1=_x1, y1=_y1,
        line={"color": "purple", "width": 5},
    )

# One-unit-wide wheat bands with a yellow outline.
# Columns: (x0, y0, x1, y1), opacity.
# NOTE(review): the band (0, 2, 7, 3) is listed twice (opacity 0.75, then
# 0.7), exactly as in the original sequence of calls -- confirm whether the
# repetition is intended.
_bands = [
    ((2, 0, 3, 7), 0.75),
    ((5, 0, 6, 7), 0.75),
    ((0, 2, 7, 3), 0.75),
    ((0, 2, 7, 3), 0.7),
    ((0, 5, 7, 6), 0.7),
]
for (_x0, _y0, _x1, _y1), _alpha in _bands:
    fig.add_shape(
        type="rect",
        fillcolor="wheat",
        opacity=_alpha,
        x0=_x0, y0=_y0, x1=_x1, y1=_y1,
        line={"color": "yellow", "width": 5},
    )

# Set every shape's xref/yref to the 'x' and 'y' axes, then fix the canvas size.
fig.update_shapes({"xref": "x", "yref": "y"})
fig.update_layout(
    width=600,
    height=600,
)