This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Broadcasting ops are slow #8219

Open
aseyboldt opened this issue Oct 11, 2017 · 25 comments

@aseyboldt

I can't see a good reason why broadcast_add(array, value) should be much slower than array + scalar, but the speed difference is almost 100x:

import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')

a_ = mx.nd.ones((2**20,))
b1 = mx.nd.ones((2**20,))
b2 = mx.nd.ones((1,))

func1 = (a + b).bind(mx.cpu(), {'a': a_, 'b': b1})
func2 = mx.sym.broadcast_add(a, b).bind(mx.cpu(), {'a': a_, 'b': b2})
func3 = (a + 1.).bind(mx.cpu(), {'a': a_})

# array + array
%timeit func1.forward()[0].wait_to_read()
# 409 µs ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# broadcast_add(array, scalar)
%timeit func2.forward()[0].wait_to_read()
# 7.02 ms ± 7.08 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# array + scalar
%timeit func3.forward()[0].wait_to_read()
# 88.7 µs ± 939 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

A bit of digging led me to this function:

https://github.com/apache/incubator-mxnet/blob/bbf1c0b9d49b432fc46b911656f7f3fdd2521f98/src/operator/tensor/broadcast_reduce-inl.h#L142-L152

This is called for each element of the result array, and does a lot of index juggling. Broadcasting can be implemented much faster by using strides of 0 in the dimensions we want to broadcast over. Additionally, it is possible to optimize some inner loops further in many cases. The best reference I know of is the numpy iterator API: https://docs.scipy.org/doc/numpy-dev/reference/c-api.iterator.html
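
To make the stride idea concrete, here is a minimal numpy sketch (not MXNet code) of what a zero-stride broadcast looks like: the smaller operand is viewed with stride 0 along the broadcast dimension, after which the kernel is a plain element-wise loop with no per-element index arithmetic.

import numpy as np

a = np.ones((4, 5))
b = np.ones(5)  # operand to be broadcast along the first axis

# View b as a (4, 5) array without copying: stride 0 along the broadcast axis
# means every "row" of the view points at the same memory.
b_view = np.lib.stride_tricks.as_strided(b, shape=(4, 5), strides=(0, b.strides[0]))

out = a + b_view  # a plain element-wise loop; no ravel/unravel per element
assert np.array_equal(out, a + b)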

The same problem seems to exist in the reduction ops in the same file.

For the benchmark I used git commit 573a010 on a Linux machine with an Intel CPU.

@piiswrong
Contributor

100x is a little surprising. The code was written this way to keep the CPU and GPU implementations consistent. If the performance difference is really that big, we can consider writing a separate implementation for the CPU.

Would you like to propose a fix?

@aseyboldt
Author

It turns out that part of the difference is because the broadcasting version is single-threaded. If I set the number of threads to 1, the difference is a bit smaller (~20x):

824 µs ± 44.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
7.01 ms ± 4.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
316 µs ± 850 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Looking at the assembly in perf, this doesn't seem all that surprising: array + scalar compiles to nice linear SIMD instructions, while the broadcasting version is a mess involving lots of div and imul instructions. I'm not sure how prefetching and pipelining are affected...

Sadly, I don't think I'll have the time to look into this in detail in the next couple of days, or probably weeks.

@piiswrong
Contributor

Assigning Chris to keep track of this.

We'll probably be able to get to this in a couple weeks.

@cjolivier01
Member

In case I get a chance to look at it, I’d like to know what to profile. What’s your model and data set? What are your tensor dimensions? I assume float32 type...

@cjolivier01
Member

Oh, heh there’s the python code... missed that on my phone’s tiny screen.

@cjolivier01
Member

By the way, things are often faster with OMP_NUM_THREADS=1 when the tensors aren't huge and the calculations are relatively simple. I assume there's an OMP loop outside of the kernel you showed.
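
For reference, a minimal sketch of how to force a single-threaded run (assuming OMP_NUM_THREADS is the relevant knob here; it has to be set before mxnet is imported):

import os
os.environ['OMP_NUM_THREADS'] = '1'  # must be set before importing mxnet

import mxnet as mx
# ... run the benchmark as above ...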

@aseyboldt
Author

@cjolivier01 I didn't see any OpenMP code in the broadcast implementation; I don't think it is multithreaded at the moment. Only the plain add and add_scalar ops are.

@cjolivier01
Member

@aseyboldt Which code version?

@cjolivier01
Member

cjolivier01 commented Dec 6, 2017

By the way, I ran this script several times:

import mxnet as mx
import time

a = mx.sym.var('a')
b = mx.sym.var('b')

a_ = mx.nd.ones((2**20,))
b1 = mx.nd.ones((2**20,))
b2 = mx.nd.ones((1,))

func1 = (a + b).bind(mx.cpu(), {'a': a_, 'b': b1})
func2 = mx.sym.broadcast_add(a, b).bind(mx.cpu(), {'a': a_, 'b': b2})
func3 = (a + 1.).bind(mx.cpu(), {'a': a_})

for x in range(4):
  print("PASS: {}...", x)

  # array + array
  start = time.time()
  for i in range(1000):
    func1.forward()[0].wait_to_read()
  print("func1: {}".format(time.time() - start))


  # broadcast_add(array, scalar)
  start = time.time()
  for i in range(1000):
    func2.forward()[0].wait_to_read()
  print("func2: {}".format(time.time() - start))


  # array + scalar
  start = time.time()
  for i in range(1000):
    func3.forward()[0].wait_to_read()
  print("func3: {}".format(time.time() - start))

  print(" ")

@cjolivier01
Member

I get this output:
('PASS: {}...', 0)
func1: 0.298206090927
func2: 0.741135120392
func3: 0.212859153748

('PASS: {}...', 1)
func1: 0.148470878601
func2: 0.578938007355
func3: 0.126513957977

('PASS: {}...', 2)
func1: 0.137872934341
func2: 0.518090963364
func3: 0.144251823425

('PASS: {}...', 3)
func1: 0.17519402504
func2: 0.5326359272
func3: 0.159995079041

@cjolivier01
Member

broadcast_add has more overhead than a simple elemwise_add, so I would expect it to be a bit slower. I believe the current implementation of "+" uses elemwise_add if the shapes of a and b are the same (the a + b case), and broadcast_add() if they differ.
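
For reference, a minimal illustration of the two operators being compared; exactly when "+" dispatches to one or the other depends on the MXNet version:

import mxnet as mx

x = mx.nd.ones((4, 5))
y = mx.nd.ones((4, 5))
s = mx.nd.ones((1, 5))

same = mx.nd.elemwise_add(x, y)    # identical shapes: plain element-wise kernel
bcast = mx.nd.broadcast_add(x, s)  # differing shapes: explicit broadcast kernel
print(same.shape, bcast.shape)     # (4, 5) (4, 5)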

@cjolivier01
Member

It is not at all surprising that the scalar add (the third one) is the fastest. It does fewer memory reads and has far less overhead in general.

@cjolivier01
Member

Note that @piiswrong refactored broadcast a few weeks ago...

@aseyboldt
Author

@cjolivier01 Thanks for looking into this. 🙂

I haven't updated to MXNet 1.0 yet, so it is possible that this is fixed now (I only have a slow internet connection at the moment, so I can't update). Looking at the code, however, I don't think so.

The broadcasting array + array shouldn't be much slower than the plain array + array, especially if one of the arrays is smaller than the other, since that reduces the required memory traffic, and memory bandwidth should be the limiting factor for simple ops on large arrays. This can be seen when we compare to numpy:

import os
os.environ['OMP_NUM_THREADS'] = '1'

import numpy as np
import mxnet as mx
import time

a = mx.sym.var('a')
b = mx.sym.var('b')

a_ = mx.nd.ones((2**17, 10, 10))
b_ = mx.nd.ones((1,))
c_ = a_.copy()

x = a_.asnumpy()
y = b_.asnumpy()
z = c_.asnumpy()

func1 = (a + b).bind(mx.cpu(), {'a': a_, 'b': c_})
func2 = mx.sym.broadcast_add(a, b).bind(mx.cpu(), {'a': a_, 'b': b_})

for _ in range(2):
    # elemwise
    start = time.time()
    for i in range(100):
        func1.forward()[0].wait_to_read()
    print("func1: {}".format(time.time() - start))


    # broadcast_add(array, array)
    start = time.time()
    for i in range(100):
        func2.forward()[0].wait_to_read()
    print("func2: {}".format(time.time() - start))

    # numpy elemwise
    start = time.time()
    out = np.zeros_like(x)
    for i in range(100):
        np.add(x, z, out=out)
    print("numpy1: {}".format(time.time() - start))
    
    # numpy broadcast
    start = time.time()
    for i in range(100):
        np.add(x, y, out=out)
    print("numpy2: {}".format(time.time() - start))
    
    print()

which gives me (on a different machine than the last benchmark):

func1: 0.9796142578125
func2: 9.832738876342773
numpy1: 0.9367139339447021
numpy2: 0.6408178806304932

func1: 0.927008867263794
func2: 10.026437997817993
numpy1: 1.091845989227295
numpy2: 0.646554708480835

For numpy the broadcasting op is faster than the normal one; for mxnet it is 10x slower.

In the non-broadcasting case both numpy and mxnet are bound by memory bandwidth; this is still more or less true in the broadcasting case for numpy, but not for mxnet. This seems to happen in general for the broadcasting ops in mxnet, not only when a scalar is added. (Although numpy can't saturate the memory bandwidth in some cases either, it never slows down nearly as much as mxnet.)
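
As a back-of-envelope check of the bandwidth argument (assuming the ops are purely bandwidth-bound):

# One add over (2**17, 10, 10) float32 elements, 4 bytes each:
n_bytes = (2**17 * 10 * 10) * 4  # ~52 MB per array

elemwise_traffic = 3 * n_bytes   # read a, read b, write out -> ~157 MB
broadcast_traffic = 2 * n_bytes  # read a, write out (b is a single element) -> ~105 MB
# So a bandwidth-bound broadcast add should actually be *faster* than the
# element-wise one, which is what the numpy numbers above show.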

My guess as to why func2 is so much slower than func1 is that the index juggling in ravel and unravel takes time and interferes with prefetching. Other explanations could be that some array is traversed in the wrong order (but I don't think this is the case), or that the branch for addto slows things down (but I don't see how that would account for a factor of 10).
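
To make the guess concrete, here is a plain-Python illustration (not the actual kernel) of the per-element work that unravelling a flat index implies:

def unravel(flat_idx, shape):
    # One integer mod and one integer div per dimension, per output element.
    coords = []
    for dim in reversed(shape):
        coords.append(flat_idx % dim)
        flat_idx //= dim
    return tuple(reversed(coords))

print(unravel(1234567, (2**17, 10, 10)))  # -> (12345, 6, 7)
# A contiguous element-wise kernel needs none of this and can be vectorized.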

@sxjscience
Member

I've run the code from @aseyboldt again using a 2.xlarge instance and the current master, and the result is:

func1: 1.5421907901763916
func2: 3.8357701301574707
numpy1: 1.530024528503418
numpy2: 1.0844464302062988

func1: 1.5544826984405518
func2: 3.8282980918884277
numpy1: 1.5478718280792236
numpy2: 1.0857605934143066


@anirudh2290
Member

anirudh2290 commented May 1, 2018

Taking a look now. I tried to run the script provided by @cjolivier01 on the latest master (97da5e3). I saw broadcast ops being about 4x slower over 1000 runs, not 100x. I am running on a p2.8xlarge.

@sandeep-krishnamurthy
Contributor

Thanks @anirudh2290 for looking into this. This will be very beneficial.

@cjolivier01 cjolivier01 removed their assignment May 1, 2018

@anirudh2290
Member

anirudh2290 commented Jun 9, 2018

I have been looking at this issue. The MXNet forward pass for broadcast_add is much faster than TensorFlow's and PyTorch's. To give some numbers (experiments conducted on my p2.8xlarge setup, testing CPU performance):

For broadcasting a tensor of shape (1,) to a tensor of shape (2**17, 10, 10), only forward pass:
pytorch: 0.6 seconds
mxnet: 0.4 seconds
tensorflow: 2.1 seconds.

When we include both forward and backward pass:
pytorch: 5.1 seconds
mxnet: 16 seconds
tensorflow: 2.2 seconds

So we decided to look at the MXNet backward pass and try out some optimizations. We tried using LaunchEx so that each thread gets a bigger chunk of the workload; this by itself doesn't help. The bottleneck for the backward pass is the for loop that runs in each thread: https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/broadcast_reduce-inl.h#L164

There is a lot of repeated work computing coords and dot inside that for loop. We tried caching this computation, which improves the speed by more than 1.4x, to around 12 seconds for MXNet. The drawback is that it requires extra memory. You can see the rough implementation here: https://github.com/anirudh2290/mxnet/blob/cached_broadcast/src/operator/tensor/broadcast_reduce-inl.h#L233
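
A rough numpy sketch of the caching idea (illustrative only; the real change is in the C++ kernel on the linked branch, and the shapes and names here are made up): precompute, once, the flat output index that each element of the incoming gradient accumulates into, then reuse that mapping on every backward pass.

import numpy as np

in_shape = (4, 5, 3)    # shape of the incoming gradient
out_shape = (4, 1, 3)   # gradient w.r.t. the broadcast input (reduced along the middle axis)

# Precompute once: flat output index per input element (stride 0 along the broadcast axis).
out_strides = np.array([3, 0, 1])                    # contiguous strides of out_shape, 0 where dim == 1
coords = np.indices(in_shape).reshape(3, -1)         # coordinates of every input element
cached_idx = (out_strides[:, None] * coords).sum(0)  # the cached mapping

grad_in = np.ones(np.prod(in_shape))
grad_out = np.zeros(np.prod(out_shape))
np.add.at(grad_out, cached_idx, grad_in)             # reuse cached_idx on every backward pass
print(grad_out.reshape(out_shape))                   # every entry is 5.0 (summed over the broadcast axis)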

We suspect (not yet confirmed) that TensorFlow uses the Eigen library to do the reduce. This is something we should run experiments on and consider as a replacement as we deprecate mshadow.

Next steps:

  1. Introduce the cached solution to MXNet as a stop gap solution.
  2. Investigate eigen library and whether it will be worthwhile to add it as a MXNet dependency.

Here is the script. Please let me know if there is any caveat that I have missed:

import time
import numpy as np
import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')

a_ = mx.nd.ones((2**17, 10, 10))
b_ = mx.nd.ones((1,))

func2 = mx.sym.broadcast_add(a, b).bind(mx.cpu(), args={'a': a_, 'b': b_}, args_grad = {'a': mx.nd.ones((2**17, 10, 10)), 'b': mx.nd.ones((1))})

for _ in range(4):
    # broadcast_add(array, array)
    start = time.time()
    for i in range(100):
        out = func2.forward(is_train=True)[0]
        func2.backward(mx.nd.ones((2**17, 10, 10)))
    mx.nd.waitall()
    print("mxnet time taken is: {}".format(time.time() - start))

import torch

for i in range(4):

    start = time.time()
    for j in range(100):
        x = torch.ones((2**17, 10, 10), requires_grad=True)
        y = torch.ones((1), requires_grad=True)
        z = x + y
        z.backward(torch.ones((2**17, 10, 10)), retain_graph=True)
    print("torch time taken is {}".format(time.time() - start))

import tensorflow as tf
a = tf.ones([2**17, 10, 10], name='a')
b = tf.ones([1], name='b')
add_op = a + b
g = tf.gradients(add_op, [a,b])

for x in range(4):
    with tf.Session() as session:
        start = time.time()
        for i in range(1, 100):
            grad_vals = session.run(g)
        print("tf time taken is: {}".format(time.time() - start))

@piiswrong @andreaolgiati @srochel

@lupesko
Contributor

lupesko commented Oct 10, 2018

@pengzhao-intel is this something you guys can help with?

@pengzhao-intel
Contributor

OK, I will take a look at this issue.

@pengzhao-intel
Contributor

@rongzha1 will help with this issue.

@ChaiBapchya
Contributor

Curious to know if this is still being worked on? @pengzhao-intel @anirudh2290

@anirudh2290
Member

For 1:

Introduce the cached solution to MXNet as a stop gap solution.

This was addressed by PR #11252.

For 2:

Investigate eigen library and whether it will be worthwhile to add it as a MXNet dependency.

This may not be a trivial effort and requires a detailed performance analysis comparing mshadow with Eigen for the common use cases. Also, this affects only ops not supported by MKLDNN, and the benefits may not be worth the effort.
