In [1]:
# Useful while developing; ensures compiler.py gets reloaded.
%load_ext autoreload
%autoreload 2
In [2]:
from subprocess import run, CalledProcessError
from tempfile import NamedTemporaryFile
import ctypes
from from_relay import convert_model
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload
import numpy as np
from time import perf_counter
from lib.WeightStationarySystolicArray import WeightStationarySystolicArray
from lib.Conv2D import Conv2D
from lib.Dense import Dense
from lib.utils import *
import itertools
from pandas import DataFrame
import json
from compiler import Compiler
import os
In [3]:
module, params = get_workload(layout="NHWC", image_shape=(224, 224, 3))
In [4]:
layers = convert_model(module)
op not implemented: nn.batch_norm
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.max_pool2d
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: add
op not implemented: nn.batch_norm
op not implemented: nn.relu
op not implemented: nn.global_avg_pool2d
op not implemented: nn.batch_flatten
op not implemented: nn.bias_add
op not implemented: nn.softmax
In [5]:
SA_SIZE = [2 ** i for i in range(3, 9)]
GB_DEPTH = [2 ** i for i in range(15, 18)]
# GB_DEPTH = [ 98304, ]
ACC_DEPTH = [2 ** i for i in range(11, 14)]
# ACC_DEPTH = [ 4096, ]

records = []

for sa_size, gb_depth, acc_depth in itertools.product(SA_SIZE, GB_DEPTH, ACC_DEPTH):
    TPUv1 = WeightStationarySystolicArray(
        name="Google TPUv1",
        h=sa_size,
        w=sa_size,
        freq=700e6,
        wld_bw=30 * (1024 ** 3),
        act_bw=10 * (1024 ** 3),
        gb_depth=gb_depth,
        acc_depth=acc_depth,
    )
    trace_fd = None
    TPUv1.reset_simulation()

    try:
        t_layers = TPUv1.simulate_network(layers, debug=False, trace_fd=trace_fd)
        t_sim = TPUv1.get_simulation_time("s")
        fps = 1.0 / t_sim[0]
        print(
            "{}x{} array, GB depth={}, ACC depth={} ---> {:.2f} FPS".format(
                sa_size, sa_size, gb_depth, acc_depth, fps
            )
        )

        records.append(
            {
                "rows": sa_size,
                "columns": sa_size,
                "global buffer depth": gb_depth,
                "accumulator depth": acc_depth,
                "frames per second": fps,
            }
        )
    except Exception as e:
        print(
            "{}x{} array, GB depth={}, ACC depth={} ---> {}".format(
                sa_size, sa_size, gb_depth, acc_depth, e
            )
        )
df = DataFrame(data=records)
8x8 array, GB depth=32768, ACC depth=2048 ---> 22.51 FPS
8x8 array, GB depth=32768, ACC depth=4096 ---> 22.38 FPS
8x8 array, GB depth=32768, ACC depth=8192 ---> 23.91 FPS
8x8 array, GB depth=65536, ACC depth=2048 ---> 22.48 FPS
8x8 array, GB depth=65536, ACC depth=4096 ---> 22.40 FPS
8x8 array, GB depth=65536, ACC depth=8192 ---> 23.87 FPS
8x8 array, GB depth=131072, ACC depth=2048 ---> 22.51 FPS
8x8 array, GB depth=131072, ACC depth=4096 ---> 22.37 FPS
8x8 array, GB depth=131072, ACC depth=8192 ---> 23.88 FPS
16x16 array, GB depth=32768, ACC depth=2048 ---> 88.01 FPS
16x16 array, GB depth=32768, ACC depth=4096 ---> 87.01 FPS
16x16 array, GB depth=32768, ACC depth=8192 ---> 92.42 FPS
16x16 array, GB depth=65536, ACC depth=2048 ---> 88.89 FPS
16x16 array, GB depth=65536, ACC depth=4096 ---> 87.86 FPS
16x16 array, GB depth=65536, ACC depth=8192 ---> 92.89 FPS
16x16 array, GB depth=131072, ACC depth=2048 ---> 89.03 FPS
16x16 array, GB depth=131072, ACC depth=4096 ---> 88.17 FPS
16x16 array, GB depth=131072, ACC depth=8192 ---> 94.17 FPS
32x32 array, GB depth=32768, ACC depth=2048 ---> 270.30 FPS
32x32 array, GB depth=32768, ACC depth=4096 ---> 253.27 FPS
32x32 array, GB depth=32768, ACC depth=8192 ---> 252.93 FPS
32x32 array, GB depth=65536, ACC depth=2048 ---> 281.19 FPS
32x32 array, GB depth=65536, ACC depth=4096 ---> 270.91 FPS
32x32 array, GB depth=65536, ACC depth=8192 ---> 278.45 FPS
32x32 array, GB depth=131072, ACC depth=2048 ---> 288.19 FPS
32x32 array, GB depth=131072, ACC depth=4096 ---> 277.07 FPS
32x32 array, GB depth=131072, ACC depth=8192 ---> 280.24 FPS
64x64 array, GB depth=32768, ACC depth=2048 ---> Accumulator overflow.
64x64 array, GB depth=32768, ACC depth=4096 ---> Accumulator overflow.
64x64 array, GB depth=32768, ACC depth=8192 ---> 477.31 FPS
64x64 array, GB depth=65536, ACC depth=2048 ---> Accumulator overflow.
64x64 array, GB depth=65536, ACC depth=4096 ---> Accumulator overflow.
64x64 array, GB depth=65536, ACC depth=8192 ---> 477.31 FPS
64x64 array, GB depth=131072, ACC depth=2048 ---> Accumulator overflow.
64x64 array, GB depth=131072, ACC depth=4096 ---> Accumulator overflow.
64x64 array, GB depth=131072, ACC depth=8192 ---> 477.31 FPS
128x128 array, GB depth=32768, ACC depth=2048 ---> Accumulator overflow.
128x128 array, GB depth=32768, ACC depth=4096 ---> Accumulator overflow.
128x128 array, GB depth=32768, ACC depth=8192 ---> 499.77 FPS
128x128 array, GB depth=65536, ACC depth=2048 ---> Accumulator overflow.
128x128 array, GB depth=65536, ACC depth=4096 ---> Accumulator overflow.
128x128 array, GB depth=65536, ACC depth=8192 ---> 499.77 FPS
128x128 array, GB depth=131072, ACC depth=2048 ---> Accumulator overflow.
128x128 array, GB depth=131072, ACC depth=4096 ---> Accumulator overflow.
128x128 array, GB depth=131072, ACC depth=8192 ---> 499.77 FPS
256x256 array, GB depth=32768, ACC depth=2048 ---> Accumulator overflow.
256x256 array, GB depth=32768, ACC depth=4096 ---> Accumulator overflow.
256x256 array, GB depth=32768, ACC depth=8192 ---> Accumulator overflow.
256x256 array, GB depth=65536, ACC depth=2048 ---> Accumulator overflow.
256x256 array, GB depth=65536, ACC depth=4096 ---> Accumulator overflow.
256x256 array, GB depth=65536, ACC depth=8192 ---> Accumulator overflow.
256x256 array, GB depth=131072, ACC depth=2048 ---> Accumulator overflow.
256x256 array, GB depth=131072, ACC depth=4096 ---> Accumulator overflow.
256x256 array, GB depth=131072, ACC depth=8192 ---> Accumulator overflow.
In [6]:
# Row which achieves the max FPS
max_row = df.loc[df["frames per second"].idxmax()]
In [7]:
def to_JSON(rows, cols, global_buffer_depth, accumulator_depth):
    return [
        {
            "id": 0,
            "name": "systolic_array",
            "atom": "bsg_systolic_array_weight_stationary",
            "dtype": "int8",
            "rows": int(rows),
            "cols": int(cols),
            "global_buffer_depth": int(global_buffer_depth),
            "accumulator_depth": int(accumulator_depth),
        }
    ]
In [8]:
hardware_json = to_JSON(
    max_row["rows"],
    max_row["columns"],
    max_row["global buffer depth"],
    max_row["accumulator depth"],
)
hardware_json
Out[8]:
[{'id': 0,
  'name': 'systolic_array',
  'atom': 'bsg_systolic_array_weight_stationary',
  'dtype': 'int8',
  'rows': 128,
  'cols': 128,
  'global_buffer_depth': 32768,
  'accumulator_depth': 8192}]
In [9]:
compiled = Compiler(module)

Write the result to a file and format it with clang-format.

In [10]:
c_file = NamedTemporaryFile(suffix=".c")
c_file.write(compiled.get_file().encode())
c_file.flush()

assert run(["clang-format", "-i", c_file.name]).returncode == 0

The resulting C file:

In [11]:
print(open(c_file.name).read())
extern void rtml_systolic_array_weight_stationary_fc(
    int hardware_id, float *out, float *activations, float *weights,
    int input_vector_size, int output_vector_size, int batch);
extern void rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
    int hardware_id, float *out, float *activations, float *weights, int h,
    int w, int kernel_h, int kernel_w, int in_channels, int out_channels,
    int stride_h, int stride_w);
extern void batchNormInference(float *X, float *Y, int N, int H, int W, int C,
                               float *gamma, float *beta, float *mu, float *var,
                               float epsilon);
extern void softmax1D(float *X, float *Y, int N);
extern void relu(float *X, float *Y, int N, int H, int W, int C);
extern void globalAvgPool(float *X, float *Y, int N, int H, int W, int C);
extern void add_with_broadcasting(float *out, float *a, float *b,
                                  int *out_shape, int out_ndims, int *a_shape,
                                  int a_ndims, int *b_shape, int b_ndims);
extern void add(float *out, float *a, float *b);
extern void maxpool2D3x3_resnet18_op6(float *X, float *Y);
extern void zero_pad_nhwc(float *out, float *in, int h, int w, int c,
                          int pad_north, int pad_east, int pad_south,
                          int pad_west);
float b93866198280016[1][224][224][3];
float b93866198280128prepadding[1][112][112][64];
float b93866198280128[1][112][112][64];
float b93866198280688[1][112][112][64];
float b93866198280880[1][112][112][64];
float b93866198281168[1][56][56][64];
float b93866198281728[1][56][56][64];
float b93866198281920[1][56][56][64];
float b93866198282208prepadding[1][56][56][64];
float b93866198282208[1][56][56][64];
float b93866198282768[1][56][56][64];
float b93866198282960[1][56][56][64];
float b93866198283248prepadding[1][56][56][64];
float b93866198283248[1][56][56][64];
float b93866198283536prepadding[1][56][56][64];
float b93866198283536[1][56][56][64];
float b93866198283824[1][56][56][64];
float b93866198284384[1][56][56][64];
float b93866198284576[1][56][56][64];
float b93866198284864prepadding[1][56][56][64];
float b93866198284864[1][56][56][64];
float b93866198285424[1][56][56][64];
float b93866198285616[1][56][56][64];
float b93866198285904prepadding[1][56][56][64];
float b93866198285904[1][56][56][64];
float b93866198286192[1][56][56][64];
float b93866198286752[1][56][56][64];
float b93866198286944[1][56][56][64];
float b93866198287232prepadding[1][28][28][128];
float b93866198287232[1][28][28][128];
float b93866198287792[1][28][28][128];
float b93866198287984[1][28][28][128];
float b93866198288272prepadding[1][28][28][128];
float b93866198288272[1][28][28][128];
float b93866198288560prepadding[1][28][28][128];
float b93866198288560[1][28][28][128];
float b93866198288848[1][28][28][128];
float b93866198289408[1][28][28][128];
float b93866198289600[1][28][28][128];
float b93866198289888prepadding[1][28][28][128];
float b93866198289888[1][28][28][128];
float b93866198290448[1][28][28][128];
float b93866198290640[1][28][28][128];
float b93866198290928prepadding[1][28][28][128];
float b93866198290928[1][28][28][128];
float b93866198291216[1][28][28][128];
float b93866198291776[1][28][28][128];
float b93866198291968[1][28][28][128];
float b93866198292256prepadding[1][14][14][256];
float b93866198292256[1][14][14][256];
float b93866198292816[1][14][14][256];
float b93866198293008[1][14][14][256];
float b93866198293296prepadding[1][14][14][256];
float b93866198293296[1][14][14][256];
float b93866198293584prepadding[1][14][14][256];
float b93866198293584[1][14][14][256];
float b93866198293872[1][14][14][256];
float b93866198294432[1][14][14][256];
float b93866198274688[1][14][14][256];
float b93866198274976prepadding[1][14][14][256];
float b93866198274976[1][14][14][256];
float b93866198275536[1][14][14][256];
float b93866198275728[1][14][14][256];
float b93866198297328prepadding[1][14][14][256];
float b93866198297328[1][14][14][256];
float b93866198297696[1][14][14][256];
float b93866198298256[1][14][14][256];
float b93866198298448[1][14][14][256];
float b93866198298656prepadding[1][7][7][512];
float b93866198298656[1][7][7][512];
float b93866198299296[1][7][7][512];
float b93866198299488[1][7][7][512];
float b93866198299776prepadding[1][7][7][512];
float b93866198299776[1][7][7][512];
float b93866198300064prepadding[1][7][7][512];
float b93866198300064[1][7][7][512];
float b93866198300352[1][7][7][512];
float b93866198300912[1][7][7][512];
float b93866198301104[1][7][7][512];
float b93866198301392prepadding[1][7][7][512];
float b93866198301392[1][7][7][512];
float b93866198301952[1][7][7][512];
float b93866198302144[1][7][7][512];
float b93866198302432prepadding[1][7][7][512];
float b93866198302432[1][7][7][512];
float b93866198302720[1][7][7][512];
float b93866198303280[1][7][7][512];
float b93866198303472[1][7][7][512];
float b93866198303760[1][1][1][512];
float b93866198304336[1][1000];
float b93866198304624[1][1000];

void compiled(
    float *out, float *data, float *bn_data_gamma, float *bn_data_beta,
    float *bn_data_moving_mean, float *bn_data_moving_var, float *conv0_weight,
    float *bn0_gamma, float *bn0_beta, float *bn0_moving_mean,
    float *bn0_moving_var, float *stage1_unit1_bn1_gamma,
    float *stage1_unit1_bn1_beta, float *stage1_unit1_bn1_moving_mean,
    float *stage1_unit1_bn1_moving_var, float *stage1_unit1_conv1_weight,
    float *stage1_unit1_bn2_gamma, float *stage1_unit1_bn2_beta,
    float *stage1_unit1_bn2_moving_mean, float *stage1_unit1_bn2_moving_var,
    float *stage1_unit1_conv2_weight, float *stage1_unit1_sc_weight,
    float *stage1_unit2_bn1_gamma, float *stage1_unit2_bn1_beta,
    float *stage1_unit2_bn1_moving_mean, float *stage1_unit2_bn1_moving_var,
    float *stage1_unit2_conv1_weight, float *stage1_unit2_bn2_gamma,
    float *stage1_unit2_bn2_beta, float *stage1_unit2_bn2_moving_mean,
    float *stage1_unit2_bn2_moving_var, float *stage1_unit2_conv2_weight,
    float *stage2_unit1_bn1_gamma, float *stage2_unit1_bn1_beta,
    float *stage2_unit1_bn1_moving_mean, float *stage2_unit1_bn1_moving_var,
    float *stage2_unit1_conv1_weight, float *stage2_unit1_bn2_gamma,
    float *stage2_unit1_bn2_beta, float *stage2_unit1_bn2_moving_mean,
    float *stage2_unit1_bn2_moving_var, float *stage2_unit1_conv2_weight,
    float *stage2_unit1_sc_weight, float *stage2_unit2_bn1_gamma,
    float *stage2_unit2_bn1_beta, float *stage2_unit2_bn1_moving_mean,
    float *stage2_unit2_bn1_moving_var, float *stage2_unit2_conv1_weight,
    float *stage2_unit2_bn2_gamma, float *stage2_unit2_bn2_beta,
    float *stage2_unit2_bn2_moving_mean, float *stage2_unit2_bn2_moving_var,
    float *stage2_unit2_conv2_weight, float *stage3_unit1_bn1_gamma,
    float *stage3_unit1_bn1_beta, float *stage3_unit1_bn1_moving_mean,
    float *stage3_unit1_bn1_moving_var, float *stage3_unit1_conv1_weight,
    float *stage3_unit1_bn2_gamma, float *stage3_unit1_bn2_beta,
    float *stage3_unit1_bn2_moving_mean, float *stage3_unit1_bn2_moving_var,
    float *stage3_unit1_conv2_weight, float *stage3_unit1_sc_weight,
    float *stage3_unit2_bn1_gamma, float *stage3_unit2_bn1_beta,
    float *stage3_unit2_bn1_moving_mean, float *stage3_unit2_bn1_moving_var,
    float *stage3_unit2_conv1_weight, float *stage3_unit2_bn2_gamma,
    float *stage3_unit2_bn2_beta, float *stage3_unit2_bn2_moving_mean,
    float *stage3_unit2_bn2_moving_var, float *stage3_unit2_conv2_weight,
    float *stage4_unit1_bn1_gamma, float *stage4_unit1_bn1_beta,
    float *stage4_unit1_bn1_moving_mean, float *stage4_unit1_bn1_moving_var,
    float *stage4_unit1_conv1_weight, float *stage4_unit1_bn2_gamma,
    float *stage4_unit1_bn2_beta, float *stage4_unit1_bn2_moving_mean,
    float *stage4_unit1_bn2_moving_var, float *stage4_unit1_conv2_weight,
    float *stage4_unit1_sc_weight, float *stage4_unit2_bn1_gamma,
    float *stage4_unit2_bn1_beta, float *stage4_unit2_bn1_moving_mean,
    float *stage4_unit2_bn1_moving_var, float *stage4_unit2_conv1_weight,
    float *stage4_unit2_bn2_gamma, float *stage4_unit2_bn2_beta,
    float *stage4_unit2_bn2_moving_mean, float *stage4_unit2_bn2_moving_var,
    float *stage4_unit2_conv2_weight, float *bn1_gamma, float *bn1_beta,
    float *bn1_moving_mean, float *bn1_moving_var, float *fc1_weight,
    float *fc1_bias) {

  batchNormInference((float *)data, (float *)b93866198280016, 1, 224, 224, 3,
                     bn_data_gamma, bn_data_beta, bn_data_moving_mean,
                     bn_data_moving_var, 0.00002);

  zero_pad_nhwc((float *)b93866198280128prepadding, (float *)b93866198280016,
                224, 224, 3, 3, 3, 3, 3);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198280128, (float *)b93866198280128prepadding,
      conv0_weight, 230, 230, 7, 7, 3, 64, 2, 2);

  batchNormInference((float *)b93866198280128, (float *)b93866198280688, 1, 112,
                     112, 64, bn0_gamma, bn0_beta, bn0_moving_mean,
                     bn0_moving_var, 0.00002);

  relu((float *)b93866198280688, (float *)b93866198280880, 1, 112, 112, 64);

  maxpool2D3x3_resnet18_op6((float *)b93866198280880, (float *)b93866198281168);

  batchNormInference((float *)b93866198281168, (float *)b93866198281728, 1, 56,
                     56, 64, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta,
                     stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198281728, (float *)b93866198281920, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866198282208prepadding, (float *)b93866198281920,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198282208, (float *)b93866198282208prepadding,
      stage1_unit1_conv1_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  batchNormInference((float *)b93866198282208, (float *)b93866198282768, 1, 56,
                     56, 64, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta,
                     stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198282768, (float *)b93866198282960, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866198283248prepadding, (float *)b93866198282960,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198283248, (float *)b93866198283248prepadding,
      stage1_unit1_conv2_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  zero_pad_nhwc((float *)b93866198283536prepadding, (float *)b93866198281920,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198283536, (float *)b93866198283536prepadding,
      stage1_unit1_sc_weight, 56, 56, 1, 1, 64, 64, 1, 1);

  int b93866198283824_out_shape[4] = {1, 56, 56, 64};
  int b93866198283824_a_shape[4] = {1, 56, 56, 64};
  int b93866198283824_b_shape[4] = {1, 56, 56, 64};
  add_with_broadcasting(
      (float *)b93866198283824, (float *)b93866198283248,
      (float *)b93866198283536, (int *)b93866198283824_out_shape, 4,
      (int *)b93866198283824_a_shape, 4, (int *)b93866198283824_b_shape, 4);

  batchNormInference((float *)b93866198283824, (float *)b93866198284384, 1, 56,
                     56, 64, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta,
                     stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198284384, (float *)b93866198284576, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866198284864prepadding, (float *)b93866198284576,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198284864, (float *)b93866198284864prepadding,
      stage1_unit2_conv1_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  batchNormInference((float *)b93866198284864, (float *)b93866198285424, 1, 56,
                     56, 64, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta,
                     stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198285424, (float *)b93866198285616, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866198285904prepadding, (float *)b93866198285616,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198285904, (float *)b93866198285904prepadding,
      stage1_unit2_conv2_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  int b93866198286192_out_shape[4] = {1, 56, 56, 64};
  int b93866198286192_a_shape[4] = {1, 56, 56, 64};
  int b93866198286192_b_shape[4] = {1, 56, 56, 64};
  add_with_broadcasting(
      (float *)b93866198286192, (float *)b93866198285904,
      (float *)b93866198283824, (int *)b93866198286192_out_shape, 4,
      (int *)b93866198286192_a_shape, 4, (int *)b93866198286192_b_shape, 4);

  batchNormInference((float *)b93866198286192, (float *)b93866198286752, 1, 56,
                     56, 64, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta,
                     stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198286752, (float *)b93866198286944, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866198287232prepadding, (float *)b93866198286944,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198287232, (float *)b93866198287232prepadding,
      stage2_unit1_conv1_weight, 58, 58, 3, 3, 64, 128, 2, 2);

  batchNormInference((float *)b93866198287232, (float *)b93866198287792, 1, 28,
                     28, 128, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta,
                     stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198287792, (float *)b93866198287984, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866198288272prepadding, (float *)b93866198287984,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198288272, (float *)b93866198288272prepadding,
      stage2_unit1_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  zero_pad_nhwc((float *)b93866198288560prepadding, (float *)b93866198286944,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198288560, (float *)b93866198288560prepadding,
      stage2_unit1_sc_weight, 56, 56, 1, 1, 64, 128, 2, 2);

  int b93866198288848_out_shape[4] = {1, 28, 28, 128};
  int b93866198288848_a_shape[4] = {1, 28, 28, 128};
  int b93866198288848_b_shape[4] = {1, 28, 28, 128};
  add_with_broadcasting(
      (float *)b93866198288848, (float *)b93866198288272,
      (float *)b93866198288560, (int *)b93866198288848_out_shape, 4,
      (int *)b93866198288848_a_shape, 4, (int *)b93866198288848_b_shape, 4);

  batchNormInference((float *)b93866198288848, (float *)b93866198289408, 1, 28,
                     28, 128, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta,
                     stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198289408, (float *)b93866198289600, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866198289888prepadding, (float *)b93866198289600,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198289888, (float *)b93866198289888prepadding,
      stage2_unit2_conv1_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  batchNormInference((float *)b93866198289888, (float *)b93866198290448, 1, 28,
                     28, 128, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta,
                     stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198290448, (float *)b93866198290640, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866198290928prepadding, (float *)b93866198290640,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198290928, (float *)b93866198290928prepadding,
      stage2_unit2_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  int b93866198291216_out_shape[4] = {1, 28, 28, 128};
  int b93866198291216_a_shape[4] = {1, 28, 28, 128};
  int b93866198291216_b_shape[4] = {1, 28, 28, 128};
  add_with_broadcasting(
      (float *)b93866198291216, (float *)b93866198290928,
      (float *)b93866198288848, (int *)b93866198291216_out_shape, 4,
      (int *)b93866198291216_a_shape, 4, (int *)b93866198291216_b_shape, 4);

  batchNormInference((float *)b93866198291216, (float *)b93866198291776, 1, 28,
                     28, 128, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta,
                     stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198291776, (float *)b93866198291968, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866198292256prepadding, (float *)b93866198291968,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198292256, (float *)b93866198292256prepadding,
      stage3_unit1_conv1_weight, 30, 30, 3, 3, 128, 256, 2, 2);

  batchNormInference((float *)b93866198292256, (float *)b93866198292816, 1, 14,
                     14, 256, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta,
                     stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198292816, (float *)b93866198293008, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866198293296prepadding, (float *)b93866198293008,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198293296, (float *)b93866198293296prepadding,
      stage3_unit1_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  zero_pad_nhwc((float *)b93866198293584prepadding, (float *)b93866198291968,
                28, 28, 128, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198293584, (float *)b93866198293584prepadding,
      stage3_unit1_sc_weight, 28, 28, 1, 1, 128, 256, 2, 2);

  int b93866198293872_out_shape[4] = {1, 14, 14, 256};
  int b93866198293872_a_shape[4] = {1, 14, 14, 256};
  int b93866198293872_b_shape[4] = {1, 14, 14, 256};
  add_with_broadcasting(
      (float *)b93866198293872, (float *)b93866198293296,
      (float *)b93866198293584, (int *)b93866198293872_out_shape, 4,
      (int *)b93866198293872_a_shape, 4, (int *)b93866198293872_b_shape, 4);

  batchNormInference((float *)b93866198293872, (float *)b93866198294432, 1, 14,
                     14, 256, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta,
                     stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198294432, (float *)b93866198274688, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866198274976prepadding, (float *)b93866198274688,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198274976, (float *)b93866198274976prepadding,
      stage3_unit2_conv1_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866198274976, (float *)b93866198275536, 1, 14,
                     14, 256, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta,
                     stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198275536, (float *)b93866198275728, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866198297328prepadding, (float *)b93866198275728,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198297328, (float *)b93866198297328prepadding,
      stage3_unit2_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  int b93866198297696_out_shape[4] = {1, 14, 14, 256};
  int b93866198297696_a_shape[4] = {1, 14, 14, 256};
  int b93866198297696_b_shape[4] = {1, 14, 14, 256};
  add_with_broadcasting(
      (float *)b93866198297696, (float *)b93866198297328,
      (float *)b93866198293872, (int *)b93866198297696_out_shape, 4,
      (int *)b93866198297696_a_shape, 4, (int *)b93866198297696_b_shape, 4);

  batchNormInference((float *)b93866198297696, (float *)b93866198298256, 1, 14,
                     14, 256, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta,
                     stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198298256, (float *)b93866198298448, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866198298656prepadding, (float *)b93866198298448,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198298656, (float *)b93866198298656prepadding,
      stage4_unit1_conv1_weight, 16, 16, 3, 3, 256, 512, 2, 2);

  batchNormInference((float *)b93866198298656, (float *)b93866198299296, 1, 7,
                     7, 512, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta,
                     stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198299296, (float *)b93866198299488, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866198299776prepadding, (float *)b93866198299488, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198299776, (float *)b93866198299776prepadding,
      stage4_unit1_conv2_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  zero_pad_nhwc((float *)b93866198300064prepadding, (float *)b93866198298448,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198300064, (float *)b93866198300064prepadding,
      stage4_unit1_sc_weight, 14, 14, 1, 1, 256, 512, 2, 2);

  int b93866198300352_out_shape[4] = {1, 7, 7, 512};
  int b93866198300352_a_shape[4] = {1, 7, 7, 512};
  int b93866198300352_b_shape[4] = {1, 7, 7, 512};
  add_with_broadcasting(
      (float *)b93866198300352, (float *)b93866198299776,
      (float *)b93866198300064, (int *)b93866198300352_out_shape, 4,
      (int *)b93866198300352_a_shape, 4, (int *)b93866198300352_b_shape, 4);

  batchNormInference((float *)b93866198300352, (float *)b93866198300912, 1, 7,
                     7, 512, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta,
                     stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866198300912, (float *)b93866198301104, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866198301392prepadding, (float *)b93866198301104, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198301392, (float *)b93866198301392prepadding,
      stage4_unit2_conv1_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  batchNormInference((float *)b93866198301392, (float *)b93866198301952, 1, 7,
                     7, 512, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta,
                     stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198301952, (float *)b93866198302144, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866198302432prepadding, (float *)b93866198302144, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198302432, (float *)b93866198302432prepadding,
      stage4_unit2_conv2_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  int b93866198302720_out_shape[4] = {1, 7, 7, 512};
  int b93866198302720_a_shape[4] = {1, 7, 7, 512};
  int b93866198302720_b_shape[4] = {1, 7, 7, 512};
  add_with_broadcasting(
      (float *)b93866198302720, (float *)b93866198302432,
      (float *)b93866198300352, (int *)b93866198302720_out_shape, 4,
      (int *)b93866198302720_a_shape, 4, (int *)b93866198302720_b_shape, 4);

  batchNormInference((float *)b93866198302720, (float *)b93866198303280, 1, 7,
                     7, 512, bn1_gamma, bn1_beta, bn1_moving_mean,
                     bn1_moving_var, 0.00002);

  relu((float *)b93866198303280, (float *)b93866198303472, 1, 7, 7, 512);

  globalAvgPool((float *)b93866198303472, (float *)b93866198303760, 1, 7, 7,
                512);

  rtml_systolic_array_weight_stationary_fc(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866198304336, (float *)b93866198303760, fc1_weight, 512, 1000,
      1);

  int b93866198304624_out_shape[2] = {1, 1000};
  int b93866198304624_a_shape[2] = {1, 1000};
  int b93866198304624_b_shape[1] = {1000};
  add_with_broadcasting((float *)b93866198304624, (float *)b93866198304336,
                        (float *)fc1_bias, (int *)b93866198304624_out_shape, 2,
                        (int *)b93866198304624_a_shape, 2,
                        (int *)b93866198304624_b_shape, 1);

  softmax1D((float *)b93866198304624, (float *)out, 1000);
}

Compile the file with gcc.

In [12]:
o_file = NamedTemporaryFile(suffix=".o")
result = run(
    ["gcc", "-Wall", "-Werror", "-fpic", "-g", "-o", o_file.name, "-c", c_file.name],
    capture_output=True,
)
assert result.returncode == 0, result.stderr.decode()

Link into shared library.

In [13]:
# First compile "ops.c" file into a .o
ops_o_file = NamedTemporaryFile(suffix=".o")
result = run(
    [
        "gcc",
        "-lm",
        "-Wall",
        "-Werror",
        "-fpic",
        "-g",
        "-o",
        ops_o_file.name,
        "-c",
        "ops.c",
    ],
    capture_output=True,
)
assert result.returncode == 0, result.stderr.decode()

# Then compile the .so
lib_file = NamedTemporaryFile(suffix=".so")
result = run(
    ["gcc", "-shared", "-o", lib_file.name, o_file.name, ops_o_file.name],
    capture_output=True,
)
assert result.returncode == 0, result.stderr.decode()

We can list the library's symbols with nm and search for our function:

In [14]:
print(
    [
        line
        for line in run(["nm", lib_file.name], capture_output=True)
        .stdout.decode()
        .splitlines()
        if "compiled" in line
    ][0]
)
00000000000031f5 T compiled
In [15]:
func = ctypes.CDLL(lib_file.name)["compiled"]
# Not necessary; just tells Python that the function doesn't return anything.
func.restype = None

Now, we'll construct the inputs to our now-compiled model, in the order expected by the C function representing our model. We'll store them in a list:

In [16]:
input_ndarrays = []

First, we reserve space for the output of the model:

In [17]:
out = np.zeros(
    [int(v) for v in compiled._module["main"].body.checked_type.shape], dtype="float32"
)
input_ndarrays.append(out)

Next, the input "image" data (which is currently just random data):

In [18]:
data = np.random.rand(
    *[int(v) for v in compiled._module["main"].params[0].checked_type.shape]
).astype("float32")
input_ndarrays.append(data)

Finally, the parameters, in the order specified by the original Relay module's main function. The params object conveniently already contains randomly generated params, so we just get the underlying ndarrays from those params. We also check that the params are C-contiguous and aligned; otherwise, I suspect some weird stuff would happen when we run this.

In [19]:
for var in compiled._module["main"].params[1:]:
    ndarray_param = params[var.name_hint].asnumpy()
    assert ndarray_param.flags["C_CONTIGUOUS"]
    assert ndarray_param.flags["ALIGNED"]
    input_ndarrays.append(ndarray_param)

Now we're ready to run the compiled function! We convert each ndarray input to a float* using ctypes. This may take a while to run---it's a completely unoptimized version of Resnet18 running on CPU, after all!

In [20]:
time = perf_counter()
func(
    *[
        ndarray.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        for ndarray in input_ndarrays
    ]
)
print(f"Running took {perf_counter() - time} seconds!")
Running took 16.89182720001554 seconds!

Taking a look at the result, we see something that isn't numerically surprising: every value is about 1/1000. This is generally due to the fact that we're not using real data and parameters. There's no way to judge if our model is correct yet, until I go through the work of getting real trained Resnet18 parameters plugged in.

In [21]:
out
Out[21]:
array([[0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001,
        0.001]], dtype=float32)

We can compare the result to the result from running the same model through Relay. Again, this doesn't tell us anything at the moment, as we're not using real parameters or data.

In [22]:
assert np.allclose(
    relay.create_executor(mod=module).evaluate()(data, **params).asnumpy(),
    out,
    atol=1e-6,
)
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
conv2d NHWC layout is not optimized for x86 with autotvm.
Cannot find config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32'). A fallback configuration is used, which may bring great performance regression.

Now we'll run the generated hardware through Scott's hardware-generating scripts. We begin by writing the hardware design JSON to a file:

In [23]:
hardware_json_file = NamedTemporaryFile(suffix=".json", mode="w")
json.dump(hardware_json, hardware_json_file)
hardware_json_file.flush()

We then run his Python script:

In [24]:
try:
    run(
        ["./bsg_ml_atoms/scripts/chip_gen/main.py", hardware_json_file.name],
        check=True,
        capture_output=True,
    )
except CalledProcessError as e:
    print(e.stderr.decode())
In [25]:
makefile_contents = f"""
export TOP_DIR := {os.path.abspath('') + "/bsg_ml_atoms"}

.DEFAULT_GOAL=run

HARNESS     = test_harness_standard
CHIP_V      = chip_top.v
DEBUG       = 0
TRACE       = 0
CYCLE_LIMIT = 1000000

OBJECT_FILES = main.o data.o bsg_ftoa.o bsg_clog2.o

data.c:
	$(PY) generate_data.py > $@

clean: clean_sw clean_hw
	rm -rf data.c

include $(TOP_DIR)/test/common/mk/include.mk"""

makefile_file = NamedTemporaryFile(suffix=".mk", mode="w")
makefile_file.write(makefile_contents)
makefile_file.flush()

Resnet50

We need to get Resnet50 working for a DARPA milestone. Let's give it a go!

In [26]:
from tvm.relay.testing.resnet import get_workload as get_resnet

resnet50_module, resnet50_params = get_workload(
    layout="NHWC", image_shape=(224, 224, 3), num_layers=50
)

compiled = Compiler(resnet50_module)

c_file = NamedTemporaryFile(suffix=".c")
c_file.write(compiled.get_file().encode())
c_file.flush()

assert run(["clang-format", "-i", c_file.name]).returncode == 0

print(open(c_file.name).read())
extern void rtml_systolic_array_weight_stationary_fc(
    int hardware_id, float *out, float *activations, float *weights,
    int input_vector_size, int output_vector_size, int batch);
extern void rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
    int hardware_id, float *out, float *activations, float *weights, int h,
    int w, int kernel_h, int kernel_w, int in_channels, int out_channels,
    int stride_h, int stride_w);
extern void batchNormInference(float *X, float *Y, int N, int H, int W, int C,
                               float *gamma, float *beta, float *mu, float *var,
                               float epsilon);
extern void softmax1D(float *X, float *Y, int N);
extern void relu(float *X, float *Y, int N, int H, int W, int C);
extern void globalAvgPool(float *X, float *Y, int N, int H, int W, int C);
extern void add_with_broadcasting(float *out, float *a, float *b,
                                  int *out_shape, int out_ndims, int *a_shape,
                                  int a_ndims, int *b_shape, int b_ndims);
extern void add(float *out, float *a, float *b);
extern void maxpool2D3x3_resnet18_op6(float *X, float *Y);
extern void zero_pad_nhwc(float *out, float *in, int h, int w, int c,
                          int pad_north, int pad_east, int pad_south,
                          int pad_west);
float b93866206418016[1][224][224][3];
float b93866244513168prepadding[1][112][112][64];
float b93866244513168[1][112][112][64];
float b93866244879296[1][112][112][64];
float b93866231287584[1][112][112][64];
float b93866244505776[1][56][56][64];
float b93866244713808[1][56][56][64];
float b93866206492576[1][56][56][64];
float b93866245109856prepadding[1][56][56][64];
float b93866245109856[1][56][56][64];
float b93866244706080[1][56][56][64];
float b93866244470208[1][56][56][64];
float b93866244661376prepadding[1][56][56][64];
float b93866244661376[1][56][56][64];
float b93866244706224[1][56][56][64];
float b93866206589664[1][56][56][64];
float b93866231429968prepadding[1][56][56][256];
float b93866231429968[1][56][56][256];
float b93866245245792prepadding[1][56][56][256];
float b93866245245792[1][56][56][256];
float b93866206653824[1][56][56][256];
float b93866231076304[1][56][56][256];
float b93866244439760[1][56][56][256];
float b93866244878880prepadding[1][56][56][64];
float b93866244878880[1][56][56][64];
float b93866245579904[1][56][56][64];
float b93866244481392[1][56][56][64];
float b93866245222848prepadding[1][56][56][64];
float b93866245222848[1][56][56][64];
float b93866231199504[1][56][56][64];
float b93866193820768[1][56][56][64];
float b93866195142688prepadding[1][56][56][256];
float b93866195142688[1][56][56][256];
float b93866183970352[1][56][56][256];
float b93866484505792[1][56][56][256];
float b93866485366032[1][56][56][256];
float b93866480580848prepadding[1][56][56][64];
float b93866480580848[1][56][56][64];
float b93866485366112[1][56][56][64];
float b93866231173072[1][56][56][64];
float b93866206450336prepadding[1][56][56][64];
float b93866206450336[1][56][56][64];
float b93866485368672[1][56][56][64];
float b93866206633808[1][56][56][64];
float b93866206408432prepadding[1][56][56][256];
float b93866206408432[1][56][56][256];
float b93866245286320[1][56][56][256];
float b93866206633888[1][56][56][256];
float b93866245390688[1][56][56][256];
float b93866231346304prepadding[1][28][28][128];
float b93866231346304[1][28][28][128];
float b93866245360880[1][28][28][128];
float b93866190777696[1][28][28][128];
float b93866195267840prepadding[1][28][28][128];
float b93866195267840[1][28][28][128];
float b93866245286400[1][28][28][128];
float b93866184911712[1][28][28][128];
float b93866244701920prepadding[1][28][28][512];
float b93866244701920[1][28][28][512];
float b93866231385616prepadding[1][28][28][512];
float b93866231385616[1][28][28][512];
float b93866245013392[1][28][28][512];
float b93866245088784[1][28][28][512];
float b93866184397136[1][28][28][512];
float b93866244529040prepadding[1][28][28][128];
float b93866244529040[1][28][28][128];
float b93866231348912[1][28][28][128];
float b93866231346432[1][28][28][128];
float b93866231264176prepadding[1][28][28][128];
float b93866231264176[1][28][28][128];
float b93866206388944[1][28][28][128];
float b93866231319920[1][28][28][128];
float b93866231130672prepadding[1][28][28][512];
float b93866231130672[1][28][28][512];
float b93866231313328[1][28][28][512];
float b93866184397216[1][28][28][512];
float b93866206542112[1][28][28][512];
float b93866244522928prepadding[1][28][28][128];
float b93866244522928[1][28][28][128];
float b93866245013472[1][28][28][128];
float b93866245054544[1][28][28][128];
float b93866244957360prepadding[1][28][28][128];
float b93866244957360[1][28][28][128];
float b93866206446016[1][28][28][128];
float b93866206490384[1][28][28][128];
float b93866231174512prepadding[1][28][28][512];
float b93866231174512[1][28][28][512];
float b93866244569824[1][28][28][512];
float b93866244872736[1][28][28][512];
float b93866245083600[1][28][28][512];
float b93866245344592prepadding[1][28][28][128];
float b93866245344592[1][28][28][128];
float b93866244957440[1][28][28][128];
float b93866244883216[1][28][28][128];
float b93866245423440prepadding[1][28][28][128];
float b93866245423440[1][28][28][128];
float b93866245070704[1][28][28][128];
float b93866245071952[1][28][28][128];
float b93866245229088prepadding[1][28][28][512];
float b93866245229088[1][28][28][512];
float b93866245250784[1][28][28][512];
float b93866245094064[1][28][28][512];
float b93866480216064[1][28][28][512];
float b93866480199440prepadding[1][14][14][256];
float b93866480199440[1][14][14][256];
float b93866231147152[1][14][14][256];
float b93866486814416[1][14][14][256];
float b93866486511920prepadding[1][14][14][256];
float b93866486511920[1][14][14][256];
float b93866206542192[1][14][14][256];
float b93866484116992[1][14][14][256];
float b93866484523472prepadding[1][14][14][1024];
float b93866484523472[1][14][14][1024];
float b93866481323248prepadding[1][14][14][1024];
float b93866481323248[1][14][14][1024];
float b93866483979328[1][14][14][1024];
float b93866485655360[1][14][14][1024];
float b93866486511360[1][14][14][1024];
float b93866485850912prepadding[1][14][14][256];
float b93866485850912[1][14][14][256];
float b93866482291328[1][14][14][256];
float b93866482291392[1][14][14][256];
float b93866484465344prepadding[1][14][14][256];
float b93866484465344[1][14][14][256];
float b93866485907376[1][14][14][256];
float b93866485907440[1][14][14][256];
float b93866484446944prepadding[1][14][14][1024];
float b93866484446944[1][14][14][1024];
float b93866486349808[1][14][14][1024];
float b93866485526752[1][14][14][1024];
float b93866484489216[1][14][14][1024];
float b93866482426048prepadding[1][14][14][256];
float b93866482426048[1][14][14][256];
float b93866486262640[1][14][14][256];
float b93866485413856[1][14][14][256];
float b93866485760464prepadding[1][14][14][256];
float b93866485760464[1][14][14][256];
float b93866206632336[1][14][14][256];
float b93866244885888[1][14][14][256];
float b93866244484224prepadding[1][14][14][1024];
float b93866244484224[1][14][14][1024];
float b93866244686992[1][14][14][1024];
float b93866244922640[1][14][14][1024];
float b93866244922704[1][14][14][1024];
float b93866244606864prepadding[1][14][14][256];
float b93866244606864[1][14][14][256];
float b93866245067808[1][14][14][256];
float b93866245067872[1][14][14][256];
float b93866244615232prepadding[1][14][14][256];
float b93866244615232[1][14][14][256];
float b93866231272384[1][14][14][256];
float b93866231272448[1][14][14][256];
float b93866244633408prepadding[1][14][14][1024];
float b93866244633408[1][14][14][1024];
float b93866177335136[1][14][14][1024];
float b93866206601232[1][14][14][1024];
float b93866231320480[1][14][14][1024];
float b93866194576000prepadding[1][14][14][256];
float b93866194576000[1][14][14][256];
float b93866198383888[1][14][14][256];
float b93866184230400[1][14][14][256];
float b93866206633040prepadding[1][14][14][256];
float b93866206633040[1][14][14][256];
float b93866206633296[1][14][14][256];
float b93866206653168[1][14][14][256];
float b93866206624768prepadding[1][14][14][1024];
float b93866206624768[1][14][14][1024];
float b93866231120320[1][14][14][1024];
float b93866244690272[1][14][14][1024];
float b93866244690336[1][14][14][1024];
float b93866244552384prepadding[1][14][14][256];
float b93866244552384[1][14][14][256];
float b93866182786640[1][14][14][256];
float b93866182786704[1][14][14][256];
float b93866206585552prepadding[1][14][14][256];
float b93866206585552[1][14][14][256];
float b93866206411376[1][14][14][256];
float b93866206411440[1][14][14][256];
float b93866244721616prepadding[1][14][14][1024];
float b93866244721616[1][14][14][1024];
float b93866231132816[1][14][14][1024];
float b93866342138256[1][14][14][1024];
float b93866206640368[1][14][14][1024];
float b93866206549072prepadding[1][7][7][512];
float b93866206549072[1][7][7][512];
float b93866206598912[1][7][7][512];
float b93866231447296[1][7][7][512];
float b93866206498576prepadding[1][7][7][512];
float b93866206498576[1][7][7][512];
float b93866231446320[1][7][7][512];
float b93866206510752[1][7][7][512];
float b93866244612880prepadding[1][7][7][2048];
float b93866244612880[1][7][7][2048];
float b93866206487584prepadding[1][7][7][2048];
float b93866206487584[1][7][7][2048];
float b93866206491520[1][7][7][2048];
float b93866231133408[1][7][7][2048];
float b93866245073152[1][7][7][2048];
float b93866245455904prepadding[1][7][7][512];
float b93866245455904[1][7][7][512];
float b93866245458544[1][7][7][512];
float b93866245087440[1][7][7][512];
float b93866245086736prepadding[1][7][7][512];
float b93866245086736[1][7][7][512];
float b93866245463600[1][7][7][512];
float b93866244639888[1][7][7][512];
float b93866244962896prepadding[1][7][7][2048];
float b93866244962896[1][7][7][2048];
float b93866245466320[1][7][7][2048];
float b93866231234240[1][7][7][2048];
float b93866231234304[1][7][7][2048];
float b93866231202304prepadding[1][7][7][512];
float b93866231202304[1][7][7][512];
float b93866245478848[1][7][7][512];
float b93866245478912[1][7][7][512];
float b93866244748064prepadding[1][7][7][512];
float b93866244748064[1][7][7][512];
float b93866231393808[1][7][7][512];
float b93866231393872[1][7][7][512];
float b93866206558192prepadding[1][7][7][2048];
float b93866206558192[1][7][7][2048];
float b93866231353280[1][7][7][2048];
float b93866206452160[1][7][7][2048];
float b93866231107136[1][7][7][2048];
float b93866231414832[1][1][1][2048];
float b93866197774320[1][1000];
float b93866244499936[1][1000];

void compiled(
    float *out, float *data, float *bn_data_gamma, float *bn_data_beta,
    float *bn_data_moving_mean, float *bn_data_moving_var, float *conv0_weight,
    float *bn0_gamma, float *bn0_beta, float *bn0_moving_mean,
    float *bn0_moving_var, float *stage1_unit1_bn1_gamma,
    float *stage1_unit1_bn1_beta, float *stage1_unit1_bn1_moving_mean,
    float *stage1_unit1_bn1_moving_var, float *stage1_unit1_conv1_weight,
    float *stage1_unit1_bn2_gamma, float *stage1_unit1_bn2_beta,
    float *stage1_unit1_bn2_moving_mean, float *stage1_unit1_bn2_moving_var,
    float *stage1_unit1_conv2_weight, float *stage1_unit1_bn3_gamma,
    float *stage1_unit1_bn3_beta, float *stage1_unit1_bn3_moving_mean,
    float *stage1_unit1_bn3_moving_var, float *stage1_unit1_conv3_weight,
    float *stage1_unit1_sc_weight, float *stage1_unit2_bn1_gamma,
    float *stage1_unit2_bn1_beta, float *stage1_unit2_bn1_moving_mean,
    float *stage1_unit2_bn1_moving_var, float *stage1_unit2_conv1_weight,
    float *stage1_unit2_bn2_gamma, float *stage1_unit2_bn2_beta,
    float *stage1_unit2_bn2_moving_mean, float *stage1_unit2_bn2_moving_var,
    float *stage1_unit2_conv2_weight, float *stage1_unit2_bn3_gamma,
    float *stage1_unit2_bn3_beta, float *stage1_unit2_bn3_moving_mean,
    float *stage1_unit2_bn3_moving_var, float *stage1_unit2_conv3_weight,
    float *stage1_unit3_bn1_gamma, float *stage1_unit3_bn1_beta,
    float *stage1_unit3_bn1_moving_mean, float *stage1_unit3_bn1_moving_var,
    float *stage1_unit3_conv1_weight, float *stage1_unit3_bn2_gamma,
    float *stage1_unit3_bn2_beta, float *stage1_unit3_bn2_moving_mean,
    float *stage1_unit3_bn2_moving_var, float *stage1_unit3_conv2_weight,
    float *stage1_unit3_bn3_gamma, float *stage1_unit3_bn3_beta,
    float *stage1_unit3_bn3_moving_mean, float *stage1_unit3_bn3_moving_var,
    float *stage1_unit3_conv3_weight, float *stage2_unit1_bn1_gamma,
    float *stage2_unit1_bn1_beta, float *stage2_unit1_bn1_moving_mean,
    float *stage2_unit1_bn1_moving_var, float *stage2_unit1_conv1_weight,
    float *stage2_unit1_bn2_gamma, float *stage2_unit1_bn2_beta,
    float *stage2_unit1_bn2_moving_mean, float *stage2_unit1_bn2_moving_var,
    float *stage2_unit1_conv2_weight, float *stage2_unit1_bn3_gamma,
    float *stage2_unit1_bn3_beta, float *stage2_unit1_bn3_moving_mean,
    float *stage2_unit1_bn3_moving_var, float *stage2_unit1_conv3_weight,
    float *stage2_unit1_sc_weight, float *stage2_unit2_bn1_gamma,
    float *stage2_unit2_bn1_beta, float *stage2_unit2_bn1_moving_mean,
    float *stage2_unit2_bn1_moving_var, float *stage2_unit2_conv1_weight,
    float *stage2_unit2_bn2_gamma, float *stage2_unit2_bn2_beta,
    float *stage2_unit2_bn2_moving_mean, float *stage2_unit2_bn2_moving_var,
    float *stage2_unit2_conv2_weight, float *stage2_unit2_bn3_gamma,
    float *stage2_unit2_bn3_beta, float *stage2_unit2_bn3_moving_mean,
    float *stage2_unit2_bn3_moving_var, float *stage2_unit2_conv3_weight,
    float *stage2_unit3_bn1_gamma, float *stage2_unit3_bn1_beta,
    float *stage2_unit3_bn1_moving_mean, float *stage2_unit3_bn1_moving_var,
    float *stage2_unit3_conv1_weight, float *stage2_unit3_bn2_gamma,
    float *stage2_unit3_bn2_beta, float *stage2_unit3_bn2_moving_mean,
    float *stage2_unit3_bn2_moving_var, float *stage2_unit3_conv2_weight,
    float *stage2_unit3_bn3_gamma, float *stage2_unit3_bn3_beta,
    float *stage2_unit3_bn3_moving_mean, float *stage2_unit3_bn3_moving_var,
    float *stage2_unit3_conv3_weight, float *stage2_unit4_bn1_gamma,
    float *stage2_unit4_bn1_beta, float *stage2_unit4_bn1_moving_mean,
    float *stage2_unit4_bn1_moving_var, float *stage2_unit4_conv1_weight,
    float *stage2_unit4_bn2_gamma, float *stage2_unit4_bn2_beta,
    float *stage2_unit4_bn2_moving_mean, float *stage2_unit4_bn2_moving_var,
    float *stage2_unit4_conv2_weight, float *stage2_unit4_bn3_gamma,
    float *stage2_unit4_bn3_beta, float *stage2_unit4_bn3_moving_mean,
    float *stage2_unit4_bn3_moving_var, float *stage2_unit4_conv3_weight,
    float *stage3_unit1_bn1_gamma, float *stage3_unit1_bn1_beta,
    float *stage3_unit1_bn1_moving_mean, float *stage3_unit1_bn1_moving_var,
    float *stage3_unit1_conv1_weight, float *stage3_unit1_bn2_gamma,
    float *stage3_unit1_bn2_beta, float *stage3_unit1_bn2_moving_mean,
    float *stage3_unit1_bn2_moving_var, float *stage3_unit1_conv2_weight,
    float *stage3_unit1_bn3_gamma, float *stage3_unit1_bn3_beta,
    float *stage3_unit1_bn3_moving_mean, float *stage3_unit1_bn3_moving_var,
    float *stage3_unit1_conv3_weight, float *stage3_unit1_sc_weight,
    float *stage3_unit2_bn1_gamma, float *stage3_unit2_bn1_beta,
    float *stage3_unit2_bn1_moving_mean, float *stage3_unit2_bn1_moving_var,
    float *stage3_unit2_conv1_weight, float *stage3_unit2_bn2_gamma,
    float *stage3_unit2_bn2_beta, float *stage3_unit2_bn2_moving_mean,
    float *stage3_unit2_bn2_moving_var, float *stage3_unit2_conv2_weight,
    float *stage3_unit2_bn3_gamma, float *stage3_unit2_bn3_beta,
    float *stage3_unit2_bn3_moving_mean, float *stage3_unit2_bn3_moving_var,
    float *stage3_unit2_conv3_weight, float *stage3_unit3_bn1_gamma,
    float *stage3_unit3_bn1_beta, float *stage3_unit3_bn1_moving_mean,
    float *stage3_unit3_bn1_moving_var, float *stage3_unit3_conv1_weight,
    float *stage3_unit3_bn2_gamma, float *stage3_unit3_bn2_beta,
    float *stage3_unit3_bn2_moving_mean, float *stage3_unit3_bn2_moving_var,
    float *stage3_unit3_conv2_weight, float *stage3_unit3_bn3_gamma,
    float *stage3_unit3_bn3_beta, float *stage3_unit3_bn3_moving_mean,
    float *stage3_unit3_bn3_moving_var, float *stage3_unit3_conv3_weight,
    float *stage3_unit4_bn1_gamma, float *stage3_unit4_bn1_beta,
    float *stage3_unit4_bn1_moving_mean, float *stage3_unit4_bn1_moving_var,
    float *stage3_unit4_conv1_weight, float *stage3_unit4_bn2_gamma,
    float *stage3_unit4_bn2_beta, float *stage3_unit4_bn2_moving_mean,
    float *stage3_unit4_bn2_moving_var, float *stage3_unit4_conv2_weight,
    float *stage3_unit4_bn3_gamma, float *stage3_unit4_bn3_beta,
    float *stage3_unit4_bn3_moving_mean, float *stage3_unit4_bn3_moving_var,
    float *stage3_unit4_conv3_weight, float *stage3_unit5_bn1_gamma,
    float *stage3_unit5_bn1_beta, float *stage3_unit5_bn1_moving_mean,
    float *stage3_unit5_bn1_moving_var, float *stage3_unit5_conv1_weight,
    float *stage3_unit5_bn2_gamma, float *stage3_unit5_bn2_beta,
    float *stage3_unit5_bn2_moving_mean, float *stage3_unit5_bn2_moving_var,
    float *stage3_unit5_conv2_weight, float *stage3_unit5_bn3_gamma,
    float *stage3_unit5_bn3_beta, float *stage3_unit5_bn3_moving_mean,
    float *stage3_unit5_bn3_moving_var, float *stage3_unit5_conv3_weight,
    float *stage3_unit6_bn1_gamma, float *stage3_unit6_bn1_beta,
    float *stage3_unit6_bn1_moving_mean, float *stage3_unit6_bn1_moving_var,
    float *stage3_unit6_conv1_weight, float *stage3_unit6_bn2_gamma,
    float *stage3_unit6_bn2_beta, float *stage3_unit6_bn2_moving_mean,
    float *stage3_unit6_bn2_moving_var, float *stage3_unit6_conv2_weight,
    float *stage3_unit6_bn3_gamma, float *stage3_unit6_bn3_beta,
    float *stage3_unit6_bn3_moving_mean, float *stage3_unit6_bn3_moving_var,
    float *stage3_unit6_conv3_weight, float *stage4_unit1_bn1_gamma,
    float *stage4_unit1_bn1_beta, float *stage4_unit1_bn1_moving_mean,
    float *stage4_unit1_bn1_moving_var, float *stage4_unit1_conv1_weight,
    float *stage4_unit1_bn2_gamma, float *stage4_unit1_bn2_beta,
    float *stage4_unit1_bn2_moving_mean, float *stage4_unit1_bn2_moving_var,
    float *stage4_unit1_conv2_weight, float *stage4_unit1_bn3_gamma,
    float *stage4_unit1_bn3_beta, float *stage4_unit1_bn3_moving_mean,
    float *stage4_unit1_bn3_moving_var, float *stage4_unit1_conv3_weight,
    float *stage4_unit1_sc_weight, float *stage4_unit2_bn1_gamma,
    float *stage4_unit2_bn1_beta, float *stage4_unit2_bn1_moving_mean,
    float *stage4_unit2_bn1_moving_var, float *stage4_unit2_conv1_weight,
    float *stage4_unit2_bn2_gamma, float *stage4_unit2_bn2_beta,
    float *stage4_unit2_bn2_moving_mean, float *stage4_unit2_bn2_moving_var,
    float *stage4_unit2_conv2_weight, float *stage4_unit2_bn3_gamma,
    float *stage4_unit2_bn3_beta, float *stage4_unit2_bn3_moving_mean,
    float *stage4_unit2_bn3_moving_var, float *stage4_unit2_conv3_weight,
    float *stage4_unit3_bn1_gamma, float *stage4_unit3_bn1_beta,
    float *stage4_unit3_bn1_moving_mean, float *stage4_unit3_bn1_moving_var,
    float *stage4_unit3_conv1_weight, float *stage4_unit3_bn2_gamma,
    float *stage4_unit3_bn2_beta, float *stage4_unit3_bn2_moving_mean,
    float *stage4_unit3_bn2_moving_var, float *stage4_unit3_conv2_weight,
    float *stage4_unit3_bn3_gamma, float *stage4_unit3_bn3_beta,
    float *stage4_unit3_bn3_moving_mean, float *stage4_unit3_bn3_moving_var,
    float *stage4_unit3_conv3_weight, float *bn1_gamma, float *bn1_beta,
    float *bn1_moving_mean, float *bn1_moving_var, float *fc1_weight,
    float *fc1_bias) {

  batchNormInference((float *)data, (float *)b93866206418016, 1, 224, 224, 3,
                     bn_data_gamma, bn_data_beta, bn_data_moving_mean,
                     bn_data_moving_var, 0.00002);

  zero_pad_nhwc((float *)b93866244513168prepadding, (float *)b93866206418016,
                224, 224, 3, 3, 3, 3, 3);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244513168, (float *)b93866244513168prepadding,
      conv0_weight, 230, 230, 7, 7, 3, 64, 2, 2);

  batchNormInference((float *)b93866244513168, (float *)b93866244879296, 1, 112,
                     112, 64, bn0_gamma, bn0_beta, bn0_moving_mean,
                     bn0_moving_var, 0.00002);

  relu((float *)b93866244879296, (float *)b93866231287584, 1, 112, 112, 64);

  maxpool2D3x3_resnet18_op6((float *)b93866231287584, (float *)b93866244505776);

  batchNormInference((float *)b93866244505776, (float *)b93866244713808, 1, 56,
                     56, 64, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta,
                     stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866244713808, (float *)b93866206492576, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866245109856prepadding, (float *)b93866206492576,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245109856, (float *)b93866245109856prepadding,
      stage1_unit1_conv1_weight, 56, 56, 1, 1, 64, 64, 1, 1);

  batchNormInference((float *)b93866245109856, (float *)b93866244706080, 1, 56,
                     56, 64, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta,
                     stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866244706080, (float *)b93866244470208, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866244661376prepadding, (float *)b93866244470208,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244661376, (float *)b93866244661376prepadding,
      stage1_unit1_conv2_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  batchNormInference((float *)b93866244661376, (float *)b93866244706224, 1, 56,
                     56, 64, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta,
                     stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var,
                     0.00002);

  relu((float *)b93866244706224, (float *)b93866206589664, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866231429968prepadding, (float *)b93866206589664,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231429968, (float *)b93866231429968prepadding,
      stage1_unit1_conv3_weight, 56, 56, 1, 1, 64, 256, 1, 1);

  zero_pad_nhwc((float *)b93866245245792prepadding, (float *)b93866206492576,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245245792, (float *)b93866245245792prepadding,
      stage1_unit1_sc_weight, 56, 56, 1, 1, 64, 256, 1, 1);

  int b93866206653824_out_shape[4] = {1, 56, 56, 256};
  int b93866206653824_a_shape[4] = {1, 56, 56, 256};
  int b93866206653824_b_shape[4] = {1, 56, 56, 256};
  add_with_broadcasting(
      (float *)b93866206653824, (float *)b93866231429968,
      (float *)b93866245245792, (int *)b93866206653824_out_shape, 4,
      (int *)b93866206653824_a_shape, 4, (int *)b93866206653824_b_shape, 4);

  batchNormInference((float *)b93866206653824, (float *)b93866231076304, 1, 56,
                     56, 256, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta,
                     stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866231076304, (float *)b93866244439760, 1, 56, 56, 256);

  zero_pad_nhwc((float *)b93866244878880prepadding, (float *)b93866244439760,
                56, 56, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244878880, (float *)b93866244878880prepadding,
      stage1_unit2_conv1_weight, 56, 56, 1, 1, 256, 64, 1, 1);

  batchNormInference((float *)b93866244878880, (float *)b93866245579904, 1, 56,
                     56, 64, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta,
                     stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245579904, (float *)b93866244481392, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866245222848prepadding, (float *)b93866244481392,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245222848, (float *)b93866245222848prepadding,
      stage1_unit2_conv2_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  batchNormInference((float *)b93866245222848, (float *)b93866231199504, 1, 56,
                     56, 64, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta,
                     stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var,
                     0.00002);

  relu((float *)b93866231199504, (float *)b93866193820768, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866195142688prepadding, (float *)b93866193820768,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866195142688, (float *)b93866195142688prepadding,
      stage1_unit2_conv3_weight, 56, 56, 1, 1, 64, 256, 1, 1);

  int b93866183970352_out_shape[4] = {1, 56, 56, 256};
  int b93866183970352_a_shape[4] = {1, 56, 56, 256};
  int b93866183970352_b_shape[4] = {1, 56, 56, 256};
  add_with_broadcasting(
      (float *)b93866183970352, (float *)b93866195142688,
      (float *)b93866206653824, (int *)b93866183970352_out_shape, 4,
      (int *)b93866183970352_a_shape, 4, (int *)b93866183970352_b_shape, 4);

  batchNormInference((float *)b93866183970352, (float *)b93866484505792, 1, 56,
                     56, 256, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta,
                     stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var,
                     0.00002);

  relu((float *)b93866484505792, (float *)b93866485366032, 1, 56, 56, 256);

  zero_pad_nhwc((float *)b93866480580848prepadding, (float *)b93866485366032,
                56, 56, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866480580848, (float *)b93866480580848prepadding,
      stage1_unit3_conv1_weight, 56, 56, 1, 1, 256, 64, 1, 1);

  batchNormInference((float *)b93866480580848, (float *)b93866485366112, 1, 56,
                     56, 64, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta,
                     stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var,
                     0.00002);

  relu((float *)b93866485366112, (float *)b93866231173072, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866206450336prepadding, (float *)b93866231173072,
                56, 56, 64, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206450336, (float *)b93866206450336prepadding,
      stage1_unit3_conv2_weight, 58, 58, 3, 3, 64, 64, 1, 1);

  batchNormInference((float *)b93866206450336, (float *)b93866485368672, 1, 56,
                     56, 64, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta,
                     stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var,
                     0.00002);

  relu((float *)b93866485368672, (float *)b93866206633808, 1, 56, 56, 64);

  zero_pad_nhwc((float *)b93866206408432prepadding, (float *)b93866206633808,
                56, 56, 64, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206408432, (float *)b93866206408432prepadding,
      stage1_unit3_conv3_weight, 56, 56, 1, 1, 64, 256, 1, 1);

  int b93866245286320_out_shape[4] = {1, 56, 56, 256};
  int b93866245286320_a_shape[4] = {1, 56, 56, 256};
  int b93866245286320_b_shape[4] = {1, 56, 56, 256};
  add_with_broadcasting(
      (float *)b93866245286320, (float *)b93866206408432,
      (float *)b93866183970352, (int *)b93866245286320_out_shape, 4,
      (int *)b93866245286320_a_shape, 4, (int *)b93866245286320_b_shape, 4);

  batchNormInference((float *)b93866245286320, (float *)b93866206633888, 1, 56,
                     56, 256, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta,
                     stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866206633888, (float *)b93866245390688, 1, 56, 56, 256);

  zero_pad_nhwc((float *)b93866231346304prepadding, (float *)b93866245390688,
                56, 56, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231346304, (float *)b93866231346304prepadding,
      stage2_unit1_conv1_weight, 56, 56, 1, 1, 256, 128, 2, 2);

  batchNormInference((float *)b93866231346304, (float *)b93866245360880, 1, 28,
                     28, 128, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta,
                     stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245360880, (float *)b93866190777696, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866195267840prepadding, (float *)b93866190777696,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866195267840, (float *)b93866195267840prepadding,
      stage2_unit1_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  batchNormInference((float *)b93866195267840, (float *)b93866245286400, 1, 28,
                     28, 128, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta,
                     stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var,
                     0.00002);

  relu((float *)b93866245286400, (float *)b93866184911712, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866244701920prepadding, (float *)b93866184911712,
                28, 28, 128, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244701920, (float *)b93866244701920prepadding,
      stage2_unit1_conv3_weight, 28, 28, 1, 1, 128, 512, 1, 1);

  zero_pad_nhwc((float *)b93866231385616prepadding, (float *)b93866245390688,
                56, 56, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231385616, (float *)b93866231385616prepadding,
      stage2_unit1_sc_weight, 56, 56, 1, 1, 256, 512, 2, 2);

  int b93866245013392_out_shape[4] = {1, 28, 28, 512};
  int b93866245013392_a_shape[4] = {1, 28, 28, 512};
  int b93866245013392_b_shape[4] = {1, 28, 28, 512};
  add_with_broadcasting(
      (float *)b93866245013392, (float *)b93866244701920,
      (float *)b93866231385616, (int *)b93866245013392_out_shape, 4,
      (int *)b93866245013392_a_shape, 4, (int *)b93866245013392_b_shape, 4);

  batchNormInference((float *)b93866245013392, (float *)b93866245088784, 1, 28,
                     28, 512, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta,
                     stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866245088784, (float *)b93866184397136, 1, 28, 28, 512);

  zero_pad_nhwc((float *)b93866244529040prepadding, (float *)b93866184397136,
                28, 28, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244529040, (float *)b93866244529040prepadding,
      stage2_unit2_conv1_weight, 28, 28, 1, 1, 512, 128, 1, 1);

  batchNormInference((float *)b93866244529040, (float *)b93866231348912, 1, 28,
                     28, 128, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta,
                     stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866231348912, (float *)b93866231346432, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866231264176prepadding, (float *)b93866231346432,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231264176, (float *)b93866231264176prepadding,
      stage2_unit2_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  batchNormInference((float *)b93866231264176, (float *)b93866206388944, 1, 28,
                     28, 128, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta,
                     stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206388944, (float *)b93866231319920, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866231130672prepadding, (float *)b93866231319920,
                28, 28, 128, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231130672, (float *)b93866231130672prepadding,
      stage2_unit2_conv3_weight, 28, 28, 1, 1, 128, 512, 1, 1);

  int b93866231313328_out_shape[4] = {1, 28, 28, 512};
  int b93866231313328_a_shape[4] = {1, 28, 28, 512};
  int b93866231313328_b_shape[4] = {1, 28, 28, 512};
  add_with_broadcasting(
      (float *)b93866231313328, (float *)b93866231130672,
      (float *)b93866245013392, (int *)b93866231313328_out_shape, 4,
      (int *)b93866231313328_a_shape, 4, (int *)b93866231313328_b_shape, 4);

  batchNormInference((float *)b93866231313328, (float *)b93866184397216, 1, 28,
                     28, 512, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta,
                     stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var,
                     0.00002);

  relu((float *)b93866184397216, (float *)b93866206542112, 1, 28, 28, 512);

  zero_pad_nhwc((float *)b93866244522928prepadding, (float *)b93866206542112,
                28, 28, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244522928, (float *)b93866244522928prepadding,
      stage2_unit3_conv1_weight, 28, 28, 1, 1, 512, 128, 1, 1);

  batchNormInference((float *)b93866244522928, (float *)b93866245013472, 1, 28,
                     28, 128, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta,
                     stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245013472, (float *)b93866245054544, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866244957360prepadding, (float *)b93866245054544,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244957360, (float *)b93866244957360prepadding,
      stage2_unit3_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  batchNormInference((float *)b93866244957360, (float *)b93866206446016, 1, 28,
                     28, 128, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta,
                     stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206446016, (float *)b93866206490384, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866231174512prepadding, (float *)b93866206490384,
                28, 28, 128, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231174512, (float *)b93866231174512prepadding,
      stage2_unit3_conv3_weight, 28, 28, 1, 1, 128, 512, 1, 1);

  int b93866244569824_out_shape[4] = {1, 28, 28, 512};
  int b93866244569824_a_shape[4] = {1, 28, 28, 512};
  int b93866244569824_b_shape[4] = {1, 28, 28, 512};
  add_with_broadcasting(
      (float *)b93866244569824, (float *)b93866231174512,
      (float *)b93866231313328, (int *)b93866244569824_out_shape, 4,
      (int *)b93866244569824_a_shape, 4, (int *)b93866244569824_b_shape, 4);

  batchNormInference((float *)b93866244569824, (float *)b93866244872736, 1, 28,
                     28, 512, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta,
                     stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var,
                     0.00002);

  relu((float *)b93866244872736, (float *)b93866245083600, 1, 28, 28, 512);

  zero_pad_nhwc((float *)b93866245344592prepadding, (float *)b93866245083600,
                28, 28, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245344592, (float *)b93866245344592prepadding,
      stage2_unit4_conv1_weight, 28, 28, 1, 1, 512, 128, 1, 1);

  batchNormInference((float *)b93866245344592, (float *)b93866244957440, 1, 28,
                     28, 128, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta,
                     stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var,
                     0.00002);

  relu((float *)b93866244957440, (float *)b93866244883216, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866245423440prepadding, (float *)b93866244883216,
                28, 28, 128, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245423440, (float *)b93866245423440prepadding,
      stage2_unit4_conv2_weight, 30, 30, 3, 3, 128, 128, 1, 1);

  batchNormInference((float *)b93866245423440, (float *)b93866245070704, 1, 28,
                     28, 128, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta,
                     stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var,
                     0.00002);

  relu((float *)b93866245070704, (float *)b93866245071952, 1, 28, 28, 128);

  zero_pad_nhwc((float *)b93866245229088prepadding, (float *)b93866245071952,
                28, 28, 128, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245229088, (float *)b93866245229088prepadding,
      stage2_unit4_conv3_weight, 28, 28, 1, 1, 128, 512, 1, 1);

  int b93866245250784_out_shape[4] = {1, 28, 28, 512};
  int b93866245250784_a_shape[4] = {1, 28, 28, 512};
  int b93866245250784_b_shape[4] = {1, 28, 28, 512};
  add_with_broadcasting(
      (float *)b93866245250784, (float *)b93866245229088,
      (float *)b93866244569824, (int *)b93866245250784_out_shape, 4,
      (int *)b93866245250784_a_shape, 4, (int *)b93866245250784_b_shape, 4);

  batchNormInference((float *)b93866245250784, (float *)b93866245094064, 1, 28,
                     28, 512, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta,
                     stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866245094064, (float *)b93866480216064, 1, 28, 28, 512);

  zero_pad_nhwc((float *)b93866480199440prepadding, (float *)b93866480216064,
                28, 28, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866480199440, (float *)b93866480199440prepadding,
      stage3_unit1_conv1_weight, 28, 28, 1, 1, 512, 256, 2, 2);

  batchNormInference((float *)b93866480199440, (float *)b93866231147152, 1, 14,
                     14, 256, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta,
                     stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866231147152, (float *)b93866486814416, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866486511920prepadding, (float *)b93866486814416,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866486511920, (float *)b93866486511920prepadding,
      stage3_unit1_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866486511920, (float *)b93866206542192, 1, 14,
                     14, 256, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta,
                     stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206542192, (float *)b93866484116992, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866484523472prepadding, (float *)b93866484116992,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866484523472, (float *)b93866484523472prepadding,
      stage3_unit1_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  zero_pad_nhwc((float *)b93866481323248prepadding, (float *)b93866480216064,
                28, 28, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866481323248, (float *)b93866481323248prepadding,
      stage3_unit1_sc_weight, 28, 28, 1, 1, 512, 1024, 2, 2);

  int b93866483979328_out_shape[4] = {1, 14, 14, 1024};
  int b93866483979328_a_shape[4] = {1, 14, 14, 1024};
  int b93866483979328_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866483979328, (float *)b93866484523472,
      (float *)b93866481323248, (int *)b93866483979328_out_shape, 4,
      (int *)b93866483979328_a_shape, 4, (int *)b93866483979328_b_shape, 4);

  batchNormInference((float *)b93866483979328, (float *)b93866485655360, 1, 14,
                     14, 1024, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta,
                     stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866485655360, (float *)b93866486511360, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866485850912prepadding, (float *)b93866486511360,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866485850912, (float *)b93866485850912prepadding,
      stage3_unit2_conv1_weight, 14, 14, 1, 1, 1024, 256, 1, 1);

  batchNormInference((float *)b93866485850912, (float *)b93866482291328, 1, 14,
                     14, 256, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta,
                     stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866482291328, (float *)b93866482291392, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866484465344prepadding, (float *)b93866482291392,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866484465344, (float *)b93866484465344prepadding,
      stage3_unit2_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866484465344, (float *)b93866485907376, 1, 14,
                     14, 256, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta,
                     stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var,
                     0.00002);

  relu((float *)b93866485907376, (float *)b93866485907440, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866484446944prepadding, (float *)b93866485907440,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866484446944, (float *)b93866484446944prepadding,
      stage3_unit2_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  int b93866486349808_out_shape[4] = {1, 14, 14, 1024};
  int b93866486349808_a_shape[4] = {1, 14, 14, 1024};
  int b93866486349808_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866486349808, (float *)b93866484446944,
      (float *)b93866483979328, (int *)b93866486349808_out_shape, 4,
      (int *)b93866486349808_a_shape, 4, (int *)b93866486349808_b_shape, 4);

  batchNormInference((float *)b93866486349808, (float *)b93866485526752, 1, 14,
                     14, 1024, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta,
                     stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var,
                     0.00002);

  relu((float *)b93866485526752, (float *)b93866484489216, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866482426048prepadding, (float *)b93866484489216,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866482426048, (float *)b93866482426048prepadding,
      stage3_unit3_conv1_weight, 14, 14, 1, 1, 1024, 256, 1, 1);

  batchNormInference((float *)b93866482426048, (float *)b93866486262640, 1, 14,
                     14, 256, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta,
                     stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var,
                     0.00002);

  relu((float *)b93866486262640, (float *)b93866485413856, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866485760464prepadding, (float *)b93866485413856,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866485760464, (float *)b93866485760464prepadding,
      stage3_unit3_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866485760464, (float *)b93866206632336, 1, 14,
                     14, 256, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta,
                     stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206632336, (float *)b93866244885888, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866244484224prepadding, (float *)b93866244885888,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244484224, (float *)b93866244484224prepadding,
      stage3_unit3_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  int b93866244686992_out_shape[4] = {1, 14, 14, 1024};
  int b93866244686992_a_shape[4] = {1, 14, 14, 1024};
  int b93866244686992_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866244686992, (float *)b93866244484224,
      (float *)b93866486349808, (int *)b93866244686992_out_shape, 4,
      (int *)b93866244686992_a_shape, 4, (int *)b93866244686992_b_shape, 4);

  batchNormInference((float *)b93866244686992, (float *)b93866244922640, 1, 14,
                     14, 1024, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta,
                     stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var,
                     0.00002);

  relu((float *)b93866244922640, (float *)b93866244922704, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866244606864prepadding, (float *)b93866244922704,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244606864, (float *)b93866244606864prepadding,
      stage3_unit4_conv1_weight, 14, 14, 1, 1, 1024, 256, 1, 1);

  batchNormInference((float *)b93866244606864, (float *)b93866245067808, 1, 14,
                     14, 256, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta,
                     stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245067808, (float *)b93866245067872, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866244615232prepadding, (float *)b93866245067872,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244615232, (float *)b93866244615232prepadding,
      stage3_unit4_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866244615232, (float *)b93866231272384, 1, 14,
                     14, 256, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta,
                     stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var,
                     0.00002);

  relu((float *)b93866231272384, (float *)b93866231272448, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866244633408prepadding, (float *)b93866231272448,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244633408, (float *)b93866244633408prepadding,
      stage3_unit4_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  int b93866177335136_out_shape[4] = {1, 14, 14, 1024};
  int b93866177335136_a_shape[4] = {1, 14, 14, 1024};
  int b93866177335136_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866177335136, (float *)b93866244633408,
      (float *)b93866244686992, (int *)b93866177335136_out_shape, 4,
      (int *)b93866177335136_a_shape, 4, (int *)b93866177335136_b_shape, 4);

  batchNormInference((float *)b93866177335136, (float *)b93866206601232, 1, 14,
                     14, 1024, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta,
                     stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var,
                     0.00002);

  relu((float *)b93866206601232, (float *)b93866231320480, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866194576000prepadding, (float *)b93866231320480,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866194576000, (float *)b93866194576000prepadding,
      stage3_unit5_conv1_weight, 14, 14, 1, 1, 1024, 256, 1, 1);

  batchNormInference((float *)b93866194576000, (float *)b93866198383888, 1, 14,
                     14, 256, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta,
                     stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var,
                     0.00002);

  relu((float *)b93866198383888, (float *)b93866184230400, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866206633040prepadding, (float *)b93866184230400,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206633040, (float *)b93866206633040prepadding,
      stage3_unit5_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866206633040, (float *)b93866206633296, 1, 14,
                     14, 256, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta,
                     stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206633296, (float *)b93866206653168, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866206624768prepadding, (float *)b93866206653168,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206624768, (float *)b93866206624768prepadding,
      stage3_unit5_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  int b93866231120320_out_shape[4] = {1, 14, 14, 1024};
  int b93866231120320_a_shape[4] = {1, 14, 14, 1024};
  int b93866231120320_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866231120320, (float *)b93866206624768,
      (float *)b93866177335136, (int *)b93866231120320_out_shape, 4,
      (int *)b93866231120320_a_shape, 4, (int *)b93866231120320_b_shape, 4);

  batchNormInference((float *)b93866231120320, (float *)b93866244690272, 1, 14,
                     14, 1024, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta,
                     stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var,
                     0.00002);

  relu((float *)b93866244690272, (float *)b93866244690336, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866244552384prepadding, (float *)b93866244690336,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244552384, (float *)b93866244552384prepadding,
      stage3_unit6_conv1_weight, 14, 14, 1, 1, 1024, 256, 1, 1);

  batchNormInference((float *)b93866244552384, (float *)b93866182786640, 1, 14,
                     14, 256, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta,
                     stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var,
                     0.00002);

  relu((float *)b93866182786640, (float *)b93866182786704, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866206585552prepadding, (float *)b93866182786704,
                14, 14, 256, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206585552, (float *)b93866206585552prepadding,
      stage3_unit6_conv2_weight, 16, 16, 3, 3, 256, 256, 1, 1);

  batchNormInference((float *)b93866206585552, (float *)b93866206411376, 1, 14,
                     14, 256, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta,
                     stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var,
                     0.00002);

  relu((float *)b93866206411376, (float *)b93866206411440, 1, 14, 14, 256);

  zero_pad_nhwc((float *)b93866244721616prepadding, (float *)b93866206411440,
                14, 14, 256, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244721616, (float *)b93866244721616prepadding,
      stage3_unit6_conv3_weight, 14, 14, 1, 1, 256, 1024, 1, 1);

  int b93866231132816_out_shape[4] = {1, 14, 14, 1024};
  int b93866231132816_a_shape[4] = {1, 14, 14, 1024};
  int b93866231132816_b_shape[4] = {1, 14, 14, 1024};
  add_with_broadcasting(
      (float *)b93866231132816, (float *)b93866244721616,
      (float *)b93866231120320, (int *)b93866231132816_out_shape, 4,
      (int *)b93866231132816_a_shape, 4, (int *)b93866231132816_b_shape, 4);

  batchNormInference((float *)b93866231132816, (float *)b93866342138256, 1, 14,
                     14, 1024, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta,
                     stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var,
                     0.00002);

  relu((float *)b93866342138256, (float *)b93866206640368, 1, 14, 14, 1024);

  zero_pad_nhwc((float *)b93866206549072prepadding, (float *)b93866206640368,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206549072, (float *)b93866206549072prepadding,
      stage4_unit1_conv1_weight, 14, 14, 1, 1, 1024, 512, 2, 2);

  batchNormInference((float *)b93866206549072, (float *)b93866206598912, 1, 7,
                     7, 512, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta,
                     stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var,
                     0.00002);

  relu((float *)b93866206598912, (float *)b93866231447296, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866206498576prepadding, (float *)b93866231447296, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206498576, (float *)b93866206498576prepadding,
      stage4_unit1_conv2_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  batchNormInference((float *)b93866206498576, (float *)b93866231446320, 1, 7,
                     7, 512, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta,
                     stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var,
                     0.00002);

  relu((float *)b93866231446320, (float *)b93866206510752, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866244612880prepadding, (float *)b93866206510752, 7,
                7, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244612880, (float *)b93866244612880prepadding,
      stage4_unit1_conv3_weight, 7, 7, 1, 1, 512, 2048, 1, 1);

  zero_pad_nhwc((float *)b93866206487584prepadding, (float *)b93866206640368,
                14, 14, 1024, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206487584, (float *)b93866206487584prepadding,
      stage4_unit1_sc_weight, 14, 14, 1, 1, 1024, 2048, 2, 2);

  int b93866206491520_out_shape[4] = {1, 7, 7, 2048};
  int b93866206491520_a_shape[4] = {1, 7, 7, 2048};
  int b93866206491520_b_shape[4] = {1, 7, 7, 2048};
  add_with_broadcasting(
      (float *)b93866206491520, (float *)b93866244612880,
      (float *)b93866206487584, (int *)b93866206491520_out_shape, 4,
      (int *)b93866206491520_a_shape, 4, (int *)b93866206491520_b_shape, 4);

  batchNormInference((float *)b93866206491520, (float *)b93866231133408, 1, 7,
                     7, 2048, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta,
                     stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var,
                     0.00002);

  relu((float *)b93866231133408, (float *)b93866245073152, 1, 7, 7, 2048);

  zero_pad_nhwc((float *)b93866245455904prepadding, (float *)b93866245073152, 7,
                7, 2048, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245455904, (float *)b93866245455904prepadding,
      stage4_unit2_conv1_weight, 7, 7, 1, 1, 2048, 512, 1, 1);

  batchNormInference((float *)b93866245455904, (float *)b93866245458544, 1, 7,
                     7, 512, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta,
                     stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245458544, (float *)b93866245087440, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866245086736prepadding, (float *)b93866245087440, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866245086736, (float *)b93866245086736prepadding,
      stage4_unit2_conv2_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  batchNormInference((float *)b93866245086736, (float *)b93866245463600, 1, 7,
                     7, 512, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta,
                     stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var,
                     0.00002);

  relu((float *)b93866245463600, (float *)b93866244639888, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866244962896prepadding, (float *)b93866244639888, 7,
                7, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244962896, (float *)b93866244962896prepadding,
      stage4_unit2_conv3_weight, 7, 7, 1, 1, 512, 2048, 1, 1);

  int b93866245466320_out_shape[4] = {1, 7, 7, 2048};
  int b93866245466320_a_shape[4] = {1, 7, 7, 2048};
  int b93866245466320_b_shape[4] = {1, 7, 7, 2048};
  add_with_broadcasting(
      (float *)b93866245466320, (float *)b93866244962896,
      (float *)b93866206491520, (int *)b93866245466320_out_shape, 4,
      (int *)b93866245466320_a_shape, 4, (int *)b93866245466320_b_shape, 4);

  batchNormInference((float *)b93866245466320, (float *)b93866231234240, 1, 7,
                     7, 2048, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta,
                     stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var,
                     0.00002);

  relu((float *)b93866231234240, (float *)b93866231234304, 1, 7, 7, 2048);

  zero_pad_nhwc((float *)b93866231202304prepadding, (float *)b93866231234304, 7,
                7, 2048, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866231202304, (float *)b93866231202304prepadding,
      stage4_unit3_conv1_weight, 7, 7, 1, 1, 2048, 512, 1, 1);

  batchNormInference((float *)b93866231202304, (float *)b93866245478848, 1, 7,
                     7, 512, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta,
                     stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var,
                     0.00002);

  relu((float *)b93866245478848, (float *)b93866245478912, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866244748064prepadding, (float *)b93866245478912, 7,
                7, 512, 1, 1, 1, 1);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866244748064, (float *)b93866244748064prepadding,
      stage4_unit3_conv2_weight, 9, 9, 3, 3, 512, 512, 1, 1);

  batchNormInference((float *)b93866244748064, (float *)b93866231393808, 1, 7,
                     7, 512, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta,
                     stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var,
                     0.00002);

  relu((float *)b93866231393808, (float *)b93866231393872, 1, 7, 7, 512);

  zero_pad_nhwc((float *)b93866206558192prepadding, (float *)b93866231393872, 7,
                7, 512, 0, 0, 0, 0);

  rtml_systolic_array_weight_stationary_conv2d_nhwc_hwio_prepadded(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866206558192, (float *)b93866206558192prepadding,
      stage4_unit3_conv3_weight, 7, 7, 1, 1, 512, 2048, 1, 1);

  int b93866231353280_out_shape[4] = {1, 7, 7, 2048};
  int b93866231353280_a_shape[4] = {1, 7, 7, 2048};
  int b93866231353280_b_shape[4] = {1, 7, 7, 2048};
  add_with_broadcasting(
      (float *)b93866231353280, (float *)b93866206558192,
      (float *)b93866245466320, (int *)b93866231353280_out_shape, 4,
      (int *)b93866231353280_a_shape, 4, (int *)b93866231353280_b_shape, 4);

  batchNormInference((float *)b93866231353280, (float *)b93866206452160, 1, 7,
                     7, 2048, bn1_gamma, bn1_beta, bn1_moving_mean,
                     bn1_moving_var, 0.00002);

  relu((float *)b93866206452160, (float *)b93866231107136, 1, 7, 7, 2048);

  globalAvgPool((float *)b93866231107136, (float *)b93866231414832, 1, 7, 7,
                2048);

  rtml_systolic_array_weight_stationary_fc(
      0, // hardware id hardwired to 0 for monolithic case
      (float *)b93866197774320, (float *)b93866231414832, fc1_weight, 2048,
      1000, 1);

  int b93866244499936_out_shape[2] = {1, 1000};
  int b93866244499936_a_shape[2] = {1, 1000};
  int b93866244499936_b_shape[1] = {1000};
  add_with_broadcasting((float *)b93866244499936, (float *)b93866197774320,
                        (float *)fc1_bias, (int *)b93866244499936_out_shape, 2,
                        (int *)b93866244499936_a_shape, 2,
                        (int *)b93866244499936_b_shape, 1);

  softmax1D((float *)b93866244499936, (float *)out, 1000);
}