
cntk batch normalisation layers problem #396

Closed
jeandebleau opened this issue Sep 2, 2018 · 4 comments

Comments

@jeandebleau

I successfully converted several Keras pretrained models to CNTK.
I used kit_imagenet to do that.

My issue is that I would like to use these models for transfer learning. I noticed that the batch normalisation layers from Keras are not converted into real batch normalisation layers in CNTK. This makes the models nearly unusable for transfer learning, at least if one wants to retrain the BN layers.

This is the code from the generated kit_imagenet.py:

def batch_normalization(input, name, epsilon, **kwargs):
    mean = cntk.Parameter(init = __weights_dict[name]['mean'],
        name = name + "_mean")
    var = cntk.Parameter(init = __weights_dict[name]['var'],
        name = name + "_var")

    layer = (input - mean) / cntk.sqrt(var + epsilon)
    if 'scale' in __weights_dict[name]:
        scale = cntk.Parameter(init = __weights_dict[name]['scale'],
            name = name + "_scale")
        layer = scale * layer

    if 'bias' in __weights_dict[name]:
        bias = cntk.Parameter(init = __weights_dict[name]['bias'],
            name = name + "_bias")
        layer = layer + bias

    return layer
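
For reference, the generated function is plain frozen-statistics inference: the mean, variance, scale and bias are baked into the graph as fixed parameters, so nothing is tracked or updated during training. A minimal NumPy sketch of the same arithmetic (shapes made up for illustration, stand-ins for the __weights_dict entries):

import numpy as np

x       = np.random.randn(64, 7, 7).astype(np.float32)  # (C, H, W) activation, hypothetical shape
mean    = np.zeros((64, 1, 1), dtype=np.float32)         # stand-in for __weights_dict[name]['mean']
var     = np.ones((64, 1, 1), dtype=np.float32)          # stand-in for __weights_dict[name]['var']
scale   = np.ones((64, 1, 1), dtype=np.float32)
bias    = np.zeros((64, 1, 1), dtype=np.float32)
epsilon = 1e-5

# same computation as the generated batch_normalization()
y = scale * (x - mean) / np.sqrt(var + epsilon) + bias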

It should be replaced by something like this:

def batch_normalization(input, name, epsilon, **kwargs):
    layer = cntk.layers.BatchNormalization(map_rank=1, name=name)(input)

    mean = cntk.Parameter(init = __weights_dict[name]['mean'],
        name = name + "_mean")
    layer.aggregate_mean = mean

    var = cntk.Parameter(init = __weights_dict[name]['var'],
        name = name + "_var")
    layer.aggregate_variance = var
    layer.aggregate_count    = 4096.0

    if 'scale' in __weights_dict[name]:
        scale = cntk.Parameter(init = __weights_dict[name]['scale'],
            name = name + "_scale")
        layer.scale = scale

    if 'bias' in __weights_dict[name]:
        bias = cntk.Parameter(init = __weights_dict[name]['bias'],
            name = name + "_bias")
        layer.bias = bias

    return layer

Is this right?

@namizzz
Contributor

namizzz commented Sep 3, 2018

Hi @jeandebleau, thanks! I tried this code snippet, but I only got [nan nan nan ... nan nan nan] as the output of the converted CNTK model file. Have you tested the converted code?

@jeandebleau
Author

jeandebleau commented Sep 3, 2018

Hi, yes, I tested the code and managed to make it work as intended. The problem is that with the current conversion, you lose the batch normalisation layers.

Let me give you more detail.

I first load a model from Keras and write out its weights and architecture:

BackEndModel = applications.ResNet50(input_shape=(224,224,3), include_top=False, weights='imagenet')

model_json = BackEndModel.to_json()
with open("ResNet50.json", "w") as json_file:
    json_file.write(model_json)
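
The conversion command below also expects the weights in ResNet50.h5. Assuming they should come from the same model, something like this would write them (file name chosen to match the command):

BackEndModel.save_weights("ResNet50.h5")  # HDF5 weights next to the JSON architecture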

I convert it to an intermediate representation:
python -m mmdnn.conversion._script.convertToIR -f keras -d ./kit_imagenet -n ResNet50.json -w ResNet50.h5

and then generate the file kit_imagenet.py:

python -m mmdnn.conversion._script.IRToCode --dstFramework cntk --IRModelPath kit_imagenet.pb --dstModelPath kit_imagenet.py --IRWeightPath kit_imagenet.npy

In the generated kit_imagenet.py file you will find this piece of code:

def batch_normalization(input, name, epsilon, **kwargs):
    mean = cntk.Parameter(init = __weights_dict[name]['mean'],
        name = name + "_mean")
    var = cntk.Parameter(init = __weights_dict[name]['var'],
        name = name + "_var")

    layer = (input - mean) / cntk.sqrt(var + epsilon)
    if 'scale' in __weights_dict[name]:
        scale = cntk.Parameter(init = __weights_dict[name]['scale'],
            name = name + "_scale")
        layer = scale * layer

    if 'bias' in __weights_dict[name]:
        bias = cntk.Parameter(init = __weights_dict[name]['bias'],
            name = name + "_bias")
        layer = layer + bias

    return layer

To get the CNTK model, you finally call:
python -m mmdnn.conversion.examples.cntk.imagenet_test -n kit_imagenet.py -w kit_imagenet.npy --dump resnet50.cntk.dnn

If you want the BN layers to be trainable, you need to modify the code to insert a real CNTK batch normalisation layer. Otherwise the model will not contain BN layers and retraining it could be slow or impossible.
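
As a quick sanity check (a sketch, assuming the model was dumped as resnet50.cntk.dnn by the command above), you can load the dumped model and count the BatchNormalization primitives; with the original emitter code the count stays at zero:

import cntk

model = cntk.Function.load("resnet50.cntk.dnn")
# collect nodes whose primitive op is BatchNormalization
bn_nodes = [out.owner for out in cntk.logging.get_node_outputs(model)
            if out.owner is not None and out.owner.op_name == 'BatchNormalization']
print(len(bn_nodes), "BatchNormalization nodes in the converted graph")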

@namizzz
Contributor

namizzz commented Sep 4, 2018

Hi @jeandebleau, yes! I see what you mean; after conversion I also get a xxx.dnn CNTK model file. But I only got [nan nan nan ... nan nan nan] as the output of the converted CNTK model file when I tested the accuracy of this conversion with an input image. You can refer to this file to get the result and intermediate results to check the accuracy. Here is my result:

Your code:
('n15075141 toilet tissue, toilet paper, bathroom tissue', 999, nan)
('n02319095 sea urchin', 328, nan)
('n02395406 hog, pig, grunter, squealer, Sus scrofa', 341, nan)
('n02391049 zebra', 340, nan)
('n02389026 sorrel', 339, nan)

MMdnn:
('n02051845 pelican', 144, 0.77398205)
('n01616318 vulture', 23, 0.10650859)
('n01608432 kite', 21, 0.08107732)
('n02058221 albatross, mollymawk', 146, 0.0092755165)
('n03388043 fountain', 562, 0.008964523)

To find the reason, I read the code of the CNTK layers and found:

# TODO: map_rank is broken. We should specify the #slowest-changing axes. E.g. 1 would work for images and vectors. Requires C++ change.
def BatchNormalization(map_rank=default_override_or(None),  # if given then normalize only over this many dimensions. E.g. pass 1 to tie all (h,w) in a (C, H, W)-shaped input
                       init_scale=1,
                       normalization_time_constant=default_override_or(5000), blend_time_constant=0,
                       epsilon=default_override_or(0.00001), use_cntk_engine=default_override_or(False),
                       disable_regularization=default_override_or(False),
                       name=''):

MMdnn IR is channel-last (e.g. [, 112, 112, 64]) and we also emit the CNTK model code file in channel-last order. So the mean and variance cannot be assigned with your code: you would be assigning a mean array of shape [64,] into a bn.aggregate_mean of shape [112,].
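
A minimal NumPy illustration of the mismatch (shapes taken from the example above): in channel-last layout the per-channel statistics converted from Keras have shape (64,), but with map_rank=1 the leading axis (112 here) is treated as the channel axis, so the assignment cannot work:

import numpy as np

x_channels_last  = np.random.randn(112, 112, 64)  # (H, W, C) -- what the emitter currently produces
x_channels_first = np.random.randn(64, 112, 112)  # (C, H, W) -- what map_rank=1 expects

mean = np.zeros(64)  # per-channel statistics converted from Keras, shape (64,)
var  = np.ones(64)

# channel-first: the (64,) statistics broadcast over the leading channel axis
y = (x_channels_first - mean[:, None, None]) / np.sqrt(var[:, None, None] + 1e-5)

# channel-last with the same leading-axis convention: the layer would need
# statistics of shape (112,), so the (64,) arrays cannot be used
try:
    (x_channels_last - mean[:, None, None]) / np.sqrt(var[:, None, None] + 1e-5)
except ValueError as err:
    print("shape mismatch:", err)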

If you want to fine-tune your CNTK model and also use the converted weights (.npy), there are two ways:

  1. Change your cntk_emitter to channel-first mode and change the code in the def batch_normalization(input, name, epsilon, **kwargs): part, then regenerate your model code file.
  2. Just edit your converted code file into channel-first mode.

Please ping us if you have any problems! Thanks!

@jeandebleau
Author

Thanks for the answer. Indeed, I also changed the code to a channel-first representation, which also involves modifying other parts of the generated kit_imagenet.py.

namizzz closed this as completed on Sep 7, 2018