diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 8b2b6dfd2d7..c81944c1205 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -14,7 +14,11 @@ #include "ir.h" +#include +#include +#include #include +#include #include #include #include @@ -22,6 +26,7 @@ #include #include #include +#include #include "storezip.h" #include "utils.h" @@ -1441,6 +1446,486 @@ static std::string make_index_expression(const Operator* op) return index_expr; } +void Graph::flops_memops_sum() +{ + for (auto op : ops) + { + if (op->type[0] == 'F') + { + std::string sub_type = op->type.substr(2); + if (sub_type == "linear") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int out_features = op->attrs.at("data").shape[0]; + flops += input_size * out_features; + if (op->has_param("bias")) + { + flops += out_features; + } + memops += input_size + output_size; + } + else if (sub_type == "avgpool1d" + || sub_type == "avgpool2d" + || sub_type == "avgpool3d" + || sub_type == "adaptive_avgpool1d" + || sub_type == "adaptive_avgpool2d" + || sub_type == "adaptive_avgpool3d") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size; + memops += input_size + output_size; + } + else if (sub_type == "prelu" + || sub_type == "elu" + || sub_type == "leaky_relu" + || sub_type == "gelu" + || sub_type == "silu" + || sub_type == "softmax") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + extra_flops += input_size; + extra_memops += input_size + output_size; + } + else if (sub_type == "unsample" + || sub_type == "upsample_nearest" + || sub_type == "upsample_bilinear") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + extra_flops += output_size; + extra_memops += input_size + output_size; + } + else if (sub_type == "interpolate") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::vector scale_factor = op->params.at("scale_factor").ai; + extra_flops += input_size * std::accumulate(scale_factor.begin(), scale_factor.end(), 1, std::multiplies()); + extra_memops += input_size + output_size; + } + } + + else if (op->type.substr(0, 2) == "nn") + { + std::string sub_type = op->type.substr(3); + if (sub_type == "BatchNorm1d" + || sub_type == "BatchNorm2d" + || sub_type == "BatchNorm3d" + || sub_type == "GroupNorm" + || sub_type == "LayerNorm" 
+ || sub_type == "InstanceNorm1d" + || sub_type == "InstanceNorm2d" + || sub_type == "InstanceNorm3d") + { + std::vector shape = op->inputs[0]->shape; + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + if ((op->has_param("affine") && op->params.at("affine").b) + || (op->has_param("elementwise_affine") && op->params.at("elementwise_affine").b)) + { + extra_flops += 2 * num_elements; + extra_memops += 2 * (num_elements + n * c); + } + else + { + extra_flops += num_elements; + extra_memops += num_elements; + } + } + else if (sub_type == "Conv1d" + || sub_type == "Conv2d" + || sub_type == "Conv3d" + || sub_type == "ConvTranspose1d" + || sub_type == "ConvTranspose2d" + || sub_type == "ConvTranspose3d") + { + int c = op->params.at("in_channels").i; + std::vector k = op->params.at("kernel_size").ai; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int g = op->params["groups"].i; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int kernel_size = std::accumulate(k.begin() + 2, k.end(), 1, std::multiplies()); + flops += output_size * c * kernel_size / g; + memops += input_size + output_size + std::accumulate(k.begin(), k.end(), 1, std::multiplies()) * c / g; + if (op->has_param("bias")) + { + flops += output_size; + memops += output_size; + } + } + else if (sub_type == "AvgPool1d" + || sub_type == "AvgPool2d" + || sub_type == "AvgPool3d") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size; + memops += input_size + output_size; + } + else if (sub_type == "AdaptiveAvgPool1d" + || sub_type == "AdaptiveAvgPool2d" + || sub_type == "AdaptiveAvgPool3d") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::vector kernel_size; + for (size_t i = 2; i < input_shape.size(); i++) + { + kernel_size.emplace_back(output_shape[i] / input_shape[i]); + } + flops += (std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies()) + 1) * output_size; + memops += input_size + output_size; + } + else if (sub_type == "PReLU" + || sub_type == "ELU" + || sub_type == "LeakyReLU" + || sub_type == "GELU") + { + std::vector shape = op->outputs[0]->shape; + int n = shape[0]; + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + extra_flops += num_elements; + if (sub_type == "PReLU") + { + extra_memops += 2 * num_elements + n * op->params["num_parameters"].i; + } + else + { + extra_memops += 2 * num_elements; + } + } + else if (sub_type == "Tanh") + { + std::vector shape = op->outputs[0]->shape; + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + extra_flops += 2 * num_elements; + extra_memops += 2 * num_elements; + } + else if (sub_type == "Linear") + { + std::vector input_shape = op->inputs[0]->shape; + int 
input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int in_features = op->params.at("in_features").i; + int out_features = op->params.at("out_features").i; + int bias = op->has_param("bias") ? out_features : 0; + flops += (in_features * out_features + bias) * input_size / in_features; + memops += input_size + output_size + output_size * (bias ? 1 : 0); + } + else if (sub_type == "Upsample" + || sub_type == "UnsampleBilinear2d" + || sub_type == "UnsampleNearest2d") + { + std::vector input_shape = op->inputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::string mode; + if (sub_type == "Unsample") + { + mode = op->has_param("mode") ? op->params.at("mode").s : "nearest"; + } + else if (sub_type == "UnsampleBilinear2d") + { + mode = "bilinear"; + } + else if (sub_type == "UnsampleNearest2d") + { + mode = "nearest"; + } + + if (mode == "nearest") + { + extra_flops += input_size; + extra_memops += input_size + output_size; + } + else if (mode == "linear") + { + extra_flops += 5 * output_size; + extra_memops += 2 * input_size + output_size; + } + else if (mode == "bilinear") + { + extra_flops += 11 * output_size; + extra_memops += 4 * input_size + output_size; + } + else if (mode == "bicubic") + { + extra_flops += (224 + 35) * output_size; + extra_memops += 16 * input_size + output_size; + } + else if (mode == "trilinear") + { + extra_flops += (13 * 2 + 5) * input_size; + extra_memops += 8 * input_size + output_size; + } + } + else if (sub_type == "RNN") + { + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = hidden_size * (input_size + hidden_size) + hidden_size; + if (bias) + { + flops1 += 2 * hidden_size; + } + if (bi) + { + flops1 *= 2; + } + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if (bi) + { + flops2 = 3 * hidden_size * hidden_size + hidden_size; + if (bias) + { + flops2 += 2 * hidden_size; + } + flops2 *= 2 * num_layers; + } + else + { + flops2 = 2 * hidden_size * hidden_size + hidden_size; + if (bias) + { + flops2 += 2 * hidden_size; + } + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if (bias) + { + memops += 2 * hidden_size * num_layers * (bi ? 
2 : 1); + } + } + else if (sub_type == "LSTM") + { + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = 4 * hidden_size * (input_size + hidden_size) + 4 * hidden_size; + if (bias) + { + flops1 += 8 * hidden_size; + } + if (bi) + { + flops1 *= 2; + } + flops1 += 4 * hidden_size; + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if (bi) + { + flops2 = 12 * hidden_size * hidden_size + 4 * hidden_size; + if (bias) + { + flops2 += 8 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= 2 * num_layers; + } + else + { + flops2 = 4 * hidden_size * hidden_size + 4 * hidden_size; + if (bias) + { + flops2 += 8 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if (bias) + { + memops += 8 * hidden_size * num_layers * (bi ? 2 : 1); + } + } + else if (sub_type == "GRU") + { + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = 3 * hidden_size * (input_size + hidden_size) + 3 * hidden_size; + if (bias) + { + flops1 += 6 * hidden_size; + } + flops1 += 4 * hidden_size; + if (bi) + { + flops1 *= 2; + } + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if (bi) + { + flops2 = 9 * hidden_size * hidden_size + 3 * hidden_size; + if (bias) + { + flops2 += 6 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= 2 * num_layers; + } + else + { + flops2 = 6 * hidden_size * hidden_size + 3 * hidden_size; + if (bias) + { + flops2 += 6 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if (bias) + { + memops += 6 * hidden_size * num_layers * (bi ? 2 : 1); + } + } + else if (sub_type == "MultiheadAttention") + { + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int qlen = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + int klen = batch_first ? 
op->inputs[1]->shape[1] : op->inputs[1]->shape[0]; + int d_model = op->params.at("embed_dim").i; + int num_heads = op->params.at("num_heads").i; + int head_dim = d_model / num_heads; + bool bias = op->params.at("bias").b; + + // Linear transformations for Q, K, V + int flops_qkv = 3 * batch_size * qlen * d_model * d_model; + if (bias) + { + flops_qkv += 3 * batch_size * qlen * d_model; + } + + // Scaled dot-product attention + int flops_attention = batch_size * num_heads * qlen * klen * head_dim; + + // Linear transformation for output + int flops_output = batch_size * qlen * d_model * d_model; + if (bias) + { + flops_output += batch_size * qlen * d_model; + } + + flops += flops_qkv + flops_attention + flops_output; + + // Memory operations for Q, K, V + int memops_qkv = 3 * batch_size * qlen * d_model; + if (bias) + { + memops_qkv += 3 * d_model; + } + + // Memory operations for attention weights + int memops_attention = batch_size * num_heads * qlen * klen; + + // Memory operations for output + int memops_output = batch_size * qlen * d_model; + if (bias) + { + memops_output += d_model; + } + + // Total memory operations + memops += memops_qkv + memops_attention + memops_output; + } + } + + else if (op->type.substr(0, 5) == "torch") + { + std::string sub_type = op->type.substr(6); + if (sub_type == "matmul" + || sub_type == "mm" + || sub_type == "bmm") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size_1 * input_shape_2.back(); + memops += input_size_1 + input_size_2 + output_size; + } + else if (sub_type == "addmm" + || sub_type == "baddbmm") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector mat_shape_1 = op->inputs[1]->shape; + std::vector mat_shape_2 = op->inputs[2]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int mat_size_1 = std::accumulate(mat_shape_1.begin(), mat_shape_1.end(), 1, std::multiplies()); + int mat_size_2 = std::accumulate(mat_shape_2.begin(), mat_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size + mat_size_1 * mat_shape_2.back(); + memops += input_size + mat_size_1 + mat_size_2 + output_size; + } + else if (sub_type == "mul" + || sub_type == "add") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += output_size; + memops += input_size_1 + input_size_2 + output_size; + } + } + } +} + int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { FILE* pyfp = fopen(pypath.c_str(), "wb"); @@ -1532,10 +2017,10 @@ int Graph::python(const std::string& pypath, const 
std::string& pnnxbinpath)
                 for (size_t i = 0; i < param.ai.size(); i++)
                 {
                     if ((op->type == "nn.AdaptiveAvgPool2d"
-                            || op->type == "nn.AdaptiveAvgPool3d"
-                            || op->type == "nn.AdaptiveMaxPool2d"
-                            || op->type == "nn.AdaptiveMaxPool3d")
-                            && it.first == "output_size" && param.ai[i] == 0)
+                         || op->type == "nn.AdaptiveAvgPool3d"
+                         || op->type == "nn.AdaptiveMaxPool2d"
+                         || op->type == "nn.AdaptiveMaxPool3d")
+                        && it.first == "output_size" && param.ai[i] == 0)
                     {
                         fprintf(pyfp, "None");
                     }
@@ -2288,10 +2773,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
                 for (size_t i = 0; i < param.ai.size(); i++)
                 {
                     if ((op->type == "F.adaptive_avg_pool2d"
-                            || op->type == "F.adaptive_avg_pool3d"
-                            || op->type == "F.adaptive_max_pool2d"
-                            || op->type == "F.adaptive_max_pool3d")
-                            && it.first == "output_size" && param.ai[i] == 0)
+                         || op->type == "F.adaptive_avg_pool3d"
+                         || op->type == "F.adaptive_max_pool2d"
+                         || op->type == "F.adaptive_max_pool3d")
+                        && it.first == "output_size" && param.ai[i] == 0)
                     {
                         fprintf(pyfp, "None");
                     }
diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h
index 779c2eec9f1..37ee81e0a6b 100644
--- a/tools/pnnx/src/ir.h
+++ b/tools/pnnx/src/ir.h
@@ -346,6 +346,12 @@ class Graph
     std::vector<Operator*> ops;
     std::vector<Operand*> operands;

+    unsigned long long flops = 0;
+    unsigned long long memops = 0;
+    unsigned long long extra_flops = 0;
+    unsigned long long extra_memops = 0;
+    void flops_memops_sum();
+
 private:
     Graph(const Graph& rhs);
     Graph& operator=(const Graph& rhs);
diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp
index c25128032d9..a50ca679fbc 100644
--- a/tools/pnnx/src/main.cpp
+++ b/tools/pnnx/src/main.cpp
@@ -363,6 +363,12 @@ int main(int argc, char** argv)
     pnnx_graph.python(pnnxpypath, pnnxbinpath);

+    pnnx_graph.flops_memops_sum();
+    fprintf(stderr, "float ops = %.3fM\n", double(pnnx_graph.flops) / 1e6);
+    fprintf(stderr, "mem ops = %.3fM\n", double(pnnx_graph.memops) / 1e6);
+    fprintf(stderr, "extra float ops = %.3fM\n", double(pnnx_graph.extra_flops) / 1e6);
+    fprintf(stderr, "extra mem ops = %.3fM\n", double(pnnx_graph.extra_memops) / 1e6);
+
 #if BUILD_PNNX2ONNX
     pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16);
 #else
@@ -382,6 +388,5 @@ int main(int argc, char** argv)

     // pnnx_graph2.load("pnnx.param", "pnnx.bin");
     // pnnx_graph2.save("pnnx2.param", "pnnx2.bin");
-
     return 0;
 }
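
For reference, the Linear and Conv branches in Graph::flops_memops_sum() follow the conventional per-output multiply-add counts. A minimal standalone sketch of those two rules, using plain std::vector<int> shapes instead of pnnx Operand objects (the helper names are illustrative, not part of pnnx):

// flops_rules.cpp -- standalone restatement of the dense-layer counting rules.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Product of all dimensions of a shape, e.g. {N, C, H, W} -> N*C*H*W.
static uint64_t numel(const std::vector<int>& shape)
{
    return std::accumulate(shape.begin(), shape.end(), uint64_t(1), std::multiplies<uint64_t>());
}

// nn.Linear: every output element costs in_features multiply-adds, plus one add for bias.
static uint64_t linear_flops(const std::vector<int>& in_shape, int in_features, int out_features, bool bias)
{
    uint64_t rows = numel(in_shape) / in_features; // all leading dims collapsed into "rows"
    return rows * (uint64_t(in_features) * out_features + (bias ? out_features : 0));
}

// nn.Conv2d: every output element costs (in_channels / groups) * kh * kw multiply-adds.
static uint64_t conv2d_flops(const std::vector<int>& out_shape, int in_channels, int kh, int kw, int groups, bool bias)
{
    uint64_t out = numel(out_shape);
    uint64_t per_output = uint64_t(in_channels / groups) * kh * kw;
    return out * per_output + (bias ? out : 0);
}

int main()
{
    // ResNet-style first conv: Cin=3, 7x7 kernel, 1x64x112x112 output
    std::printf("conv flops = %.3fM\n", conv2d_flops({1, 64, 112, 112}, 3, 7, 7, 1, false) / 1e6);
    // 2048 -> 1000 classifier with bias
    std::printf("linear flops = %.3fM\n", linear_flops({1, 2048}, 2048, 1000, true) / 1e6);
    return 0;
}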
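The nn.MultiheadAttention branch splits the cost into the Q/K/V input projections, the QK^T score computation, and the output projection. The same decomposition in scalar form, as a standalone sketch (function and parameter names are illustrative; like the branch above, the K/V projections are counted over qlen):

#include <cstdint>
#include <cstdio>

static uint64_t mha_flops(uint64_t batch, uint64_t qlen, uint64_t klen,
                          uint64_t embed_dim, uint64_t num_heads, bool bias)
{
    uint64_t head_dim = embed_dim / num_heads;

    // in_proj for Q, K and V
    uint64_t proj_qkv = 3 * batch * qlen * embed_dim * embed_dim;
    if (bias) proj_qkv += 3 * batch * qlen * embed_dim;

    // scaled dot-product scores (QK^T)
    uint64_t scores = batch * num_heads * qlen * klen * head_dim;

    // out_proj
    uint64_t proj_out = batch * qlen * embed_dim * embed_dim;
    if (bias) proj_out += batch * qlen * embed_dim;

    return proj_qkv + scores + proj_out;
}

int main()
{
    // e.g. one self-attention layer with qlen = klen = 128, embed_dim = 768, 12 heads
    std::printf("mha flops = %.3fM\n", mha_flops(1, 128, 128, 768, 12, true) / 1e6);
    return 0;
}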
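The counters are plain public members of pnnx::Graph, so they can also be read outside of main.cpp. A minimal driver sketch, assuming a pnnx.param/pnnx.bin pair from a previous run; the load() call follows the commented-out example at the end of main.cpp:

#include <cstdio>
#include "ir.h"

int main()
{
    pnnx::Graph graph;
    if (graph.load("pnnx.param", "pnnx.bin") != 0)
        return -1;

    // walk the graph once and fill flops / memops / extra_flops / extra_memops
    graph.flops_memops_sum();

    fprintf(stderr, "float ops = %.3fM\n", double(graph.flops) / 1e6);
    fprintf(stderr, "mem ops = %.3fM\n", double(graph.memops) / 1e6);
    fprintf(stderr, "extra float ops = %.3fM\n", double(graph.extra_flops) / 1e6);
    fprintf(stderr, "extra mem ops = %.3fM\n", double(graph.extra_memops) / 1e6);
    return 0;
}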