From bf9785fa72a0f5c3eec5b516d5ea154ab07f12ab Mon Sep 17 00:00:00 2001 From: Andrew El-Kadi Date: Wed, 4 Dec 2024 14:19:21 +0000 Subject: [PATCH] Adding GenCast support. --- README.md | 220 +++-- gencast_demo_cloud_vm.ipynb | 690 ++++++++++++++ gencast_mini_demo.ipynb | 875 ++++++++++++++++++ graphcast/autoregressive.py | 2 +- graphcast/casting.py | 4 +- graphcast/deep_typed_graph_net.py | 122 ++- graphcast/denoiser.py | 851 +++++++++++++++++ graphcast/denoisers_base.py | 53 ++ .../GenCast_0p25deg_accelerator_scorecard.png | Bin 0 -> 171471 bytes .../GenCast_1p0deg_Mini_ENS_scorecard.png | Bin 0 -> 158143 bytes graphcast/docs/cloud_vm_setup.md | 242 +++++ graphcast/docs/local_runtime_popup_1.png | Bin 0 -> 92915 bytes graphcast/docs/local_runtime_popup_2.png | Bin 0 -> 96810 bytes graphcast/docs/local_runtime_url.png | Bin 0 -> 792222 bytes graphcast/docs/project.png | Bin 0 -> 156767 bytes graphcast/docs/provision_tpu.png | Bin 0 -> 40113 bytes graphcast/docs/tpu_types.png | Bin 0 -> 27827 bytes graphcast/dpm_solver_plus_plus_2s.py | 187 ++++ graphcast/gencast.py | 284 ++++++ graphcast/icosahedral_mesh.py | 4 + graphcast/mlp.py | 45 + graphcast/model_utils.py | 34 + graphcast/nan_cleaning.py | 125 +++ graphcast/rollout.py | 196 +++- graphcast/samplers_base.py | 47 + graphcast/samplers_utils.py | 431 +++++++++ graphcast/sparse_transformer.py | 577 ++++++++++++ graphcast/sparse_transformer_utils.py | 76 ++ graphcast/transformer.py | 124 +++ graphcast_demo.ipynb | 28 +- setup.py | 4 +- 31 files changed, 5110 insertions(+), 111 deletions(-) create mode 100644 gencast_demo_cloud_vm.ipynb create mode 100644 gencast_mini_demo.ipynb create mode 100644 graphcast/denoiser.py create mode 100644 graphcast/denoisers_base.py create mode 100644 graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png create mode 100644 graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png create mode 100644 graphcast/docs/cloud_vm_setup.md create mode 100644 graphcast/docs/local_runtime_popup_1.png create mode 100644 graphcast/docs/local_runtime_popup_2.png create mode 100644 graphcast/docs/local_runtime_url.png create mode 100644 graphcast/docs/project.png create mode 100644 graphcast/docs/provision_tpu.png create mode 100644 graphcast/docs/tpu_types.png create mode 100644 graphcast/dpm_solver_plus_plus_2s.py create mode 100644 graphcast/gencast.py create mode 100644 graphcast/mlp.py create mode 100644 graphcast/nan_cleaning.py create mode 100644 graphcast/samplers_base.py create mode 100644 graphcast/samplers_utils.py create mode 100644 graphcast/sparse_transformer.py create mode 100644 graphcast/sparse_transformer_utils.py create mode 100644 graphcast/transformer.py diff --git a/README.md b/README.md index c59044cf..2596691e 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,42 @@ -# GraphCast: Learning skillful medium-range global weather forecasting +# Google DeepMind GraphCast and GenCast -This package contains example code to run and train [GraphCast](https://www.science.org/doi/10.1126/science.adi2336). -It also provides three pretrained models: +This package contains example code to run and train the weather models used in the research papers [GraphCast](https://www.science.org/doi/10.1126/science.adi2336) and [GenCast](https://arxiv.org/abs/2312.15796). -1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree -resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017, - -2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree -resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from -1979 to 2015, useful to run a model with lower memory and compute constraints, - -3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 -pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on -HRES data from 2016 to 2021. This model can be initialized from HRES data (does -not require precipitation inputs). - -The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). +It also provides pretrained model weights, normalization statistics and example input data on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). Full model training requires downloading the [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5) dataset, available from [ECMWF](https://www.ecmwf.int/). This can best be -accessed as Zarr from [Weatherbench2's ERA5 data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5) (see the 6h downsampled versions). +accessed as Zarr from [Weatherbench2's ERA5 data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5). -## Overview of files +Data for operational fine-tuning can similarly be accessed at [Weatherbench2's HRES 0th frame data](https://weatherbench2.readthedocs.io/en/latest/data-guide.html#ifs-hres-t-0-analysis). -The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an -example of loading data, generating random weights or load a pre-trained -snapshot, generating predictions, computing the loss and computing gradients. -The one-step implementation of GraphCast architecture, is provided in -`graphcast.py`. +These datasets may be governed by separate terms and conditions or license provisions. Your use of such third-party materials is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use. -### Brief description of library files: +## Overview of files common to models -* `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast +* `autoregressive.py`: Wrapper used to run (and train) the one-step predictions to produce a sequence of predictions by auto-regressively feeding the outputs back as inputs at each step, in JAX a differentiable way. -* `casting.py`: Wrapper used around GraphCast to make it work using - BFloat16 precision. * `checkpoint.py`: Utils to serialize and deserialize trees. * `data_utils.py`: Utils for data preprocessing. * `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN) that operates on `TypedGraph`'s where both inputs and outputs are flat - vectors of features for each of the nodes and edges. `graphcast.py` uses - three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid - GNN, respectively. -* `graphcast.py`: The main GraphCast model architecture for one-step of - predictions. + vectors of features for each of the nodes and edges. * `grid_mesh_connectivity.py`: Tools for converting between regular grids on a sphere and triangular meshes. * `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh. * `losses.py`: Loss computations, including latitude-weighting. +* `mlp.py`: Utils for building MLPs with norm conditioning layers. * `model_utils.py`: Utilities to produce flat node and edge vector features from input grid data, and to manipulate the node output vectors back into a multilevel grid data. -* `normalization.py`: Wrapper for the one-step GraphCast used to normalize - inputs according to historical values, and targets according to historical - time differences. -* `predictor_base.py`: Defines the interface of the predictor, which GraphCast +* `normalization.py`: Wrapper used to normalize inputs according to historical + values, and targets according to historical time differences. +* `predictor_base.py`: Defines the interface of the predictor, which models and all of the wrappers implement. * `rollout.py`: Similar to `autoregressive.py` but used only at inference time using a python loop to produce longer, but non-differentiable trajectories. -* `solar_radiation.py`: Computes Top-Of-the-Atmosphere (TOA) incident solar - radiation compatible with ERA5. This is used as a forcing variable and thus - needs to be computed for target lead times in an operational setting. * `typed_graph.py`: Definition of `TypedGraph`'s. * `typed_graph_net.py`: Implementation of simple graph neural network building blocks defined over `TypedGraph`'s that can be combined to build @@ -71,11 +45,115 @@ The one-step implementation of GraphCast architecture, is provided in * `xarray_tree.py`: An implementation of tree.map_structure that works with `xarray`s. +## GenCast: Diffusion-based ensemble forecasting for medium-range weather + +This package provides four pretrained models: + +1. `GenCast 0p25deg <2019`, GenCast model at 0.25deg resolution with 13 +pressure levels and a 6 times refined icosahedral mesh. This model is trained on +ERA5 data from 1979 to 2018 (inclusive), and can be causally evaluated on 2019 +and later years. This model was described in the paper +`GenCast: Diffusion-based ensemble forecasting for medium-range weather` +(https://arxiv.org/abs/2312.15796) + +2. `GenCast 0p25deg Operational <2019`, GenCast model at 0.25deg resolution, with 13 pressure levels and a 6 +times refined icosahedral mesh. This model is trained on ERA5 data from +1979 to 2018, and fine-tuned on HRES-fc0 data from +2016 to 2021 and can be causally evaluated on 2022 and later years. +This model can make predictions in an operational setting (i.e., initialised +from HRES-fc0) + +3. `GenCast 1p0deg <2019`, GenCast model at 1deg resolution, with 13 pressure +levels and a 5 times refined icosahedral mesh. This model is +trained on ERA5 data from 1979 to 2018, and can be causally evaluated on 2019 and later years. +This model has a smaller memory footprint than the 0.25deg models + +4. `GenCast 1p0deg Mini <2019`, GenCast model at 1deg resolution, with 13 pressure levels and a +4 times refined icosahedral mesh. This model is trained on ERA5 data +from 1979 to 2018, and can be causally evaluated on 2019 and later years. +This model has the smallest memory footprint of those provided and has been +provided to enable low cost demonstrations (for example, it is runnable in a free Colab notebook). +While its performance is reasonable, it is not representative of the performance +of the GenCast models (1-3) above. For reference, a scorecard comparing its performance to ENS can be found in [docs/](https://github.com/google-deepmind/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png). Note that in this scorecard, +GenCast Mini only uses 8 member ensembles (vs. ENS' 50) so we use the fair (unbiased) +CRPS to allow for fair comparison. + +The best starting point is to open `gencast_mini_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_mini_demo.ipynb), which gives an +example of loading data, generating random weights or loading a `GenCast 1p0deg Mini <2019` +snapshot, generating predictions, computing the loss and computing gradients. +The one-step implementation of GenCast architecture is provided in +`gencast.py` and the relevant data, weights and statistics are in the `gencast/` +subdir of the Google Cloud Bucket. + +### Instructions for running GenCast on Google Cloud compute + +[cloud_vm_setup.md](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md) +contains detailed instructions on launching a Google Cloud TPU VM. This provides +a means of running models (1-3) in the separate `gencast_demo_cloud_vm.ipynb` through [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_demo_cloud_vm.ipynb). + +### Brief description of relevant library files + +* `denoiser.py`: The GenCast denoiser for one step predictions. +* `denoisers_base.py`: Defines the interface of the denoiser. +* `dpm_solver_plus_plus_2s.py`: Sampler using DPM-Solver++ 2S from [1]. +* `gencast.py`: Combines the GenCast model architecture, wrapped as a + denoiser, with a sampler to generate predictions. +* `nan_cleaning.py`: Wraps a predictor to allow it to work with data + cleaned of NaNs. Used to remove NaNs from sea surface temperature. +* `samplers_base.py`: Defines the interface of the sampler. +* `samplers_utils.py`: Utility methods for the sampler. +* `sparse_transformer.py`: General purpose sparse transformer that + operates on `TypedGraph`'s where both inputs and outputs are flat vectors of + features for each of the nodes and edges. `predictor.py` uses one of these + for the mesh GNN. +* `sparse_transformer_utils.py`: Utility methods for the sparse + transformer. +* `transformer.py`: Wraps the mesh transformer, swapping the leading + two axes of the nodes in the input graph. + +[1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic + Models, https://arxiv.org/abs/2211.01095 + +## GraphCast: Learning skillful medium-range global weather forecasting + +This package provides three pretrained models: + +1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree +resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017, + +2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree +resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from +1979 to 2015, useful to run a model with lower memory and compute constraints, + +3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 +pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on +HRES data from 2016 to 2021. This model can be initialized from HRES data (does +not require precipitation inputs). + +The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an +example of loading data, generating random weights or load a pre-trained +snapshot, generating predictions, computing the loss and computing gradients. +The one-step implementation of GraphCast architecture, is provided in +`graphcast.py` and the relevant data, weights and statistics are in the `graphcast/` +subdir of the Google Cloud Bucket. + +WARNING: For backwards compatibility, we have also left GraphCast data in the top level of the bucket. These will eventually be deleted in favour of the `graphcast/` subdir. + +### Brief description of relevant library files: + +* `casting.py`: Wrapper used around GraphCast to make it work using + BFloat16 precision. +* `graphcast.py`: The main GraphCast model architecture for one-step of + predictions. +* `solar_radiation.py`: Computes Top-Of-the-Atmosphere (TOA) incident solar + radiation compatible with ERA5. This is used as a forcing variable and thus + needs to be computed for target lead times in an operational setting. -### Dependencies. +## Dependencies. [Chex](https://github.com/deepmind/chex), [Dask](https://github.com/dask/dask), +[Dinosaur](https://github.com/google-research/dinosaur), [Haiku](https://github.com/deepmind/dm-haiku), [JAX](https://github.com/google/jax), [JAXline](https://github.com/deepmind/jaxline), @@ -85,39 +163,29 @@ The one-step implementation of GraphCast architecture, is provided in [Python](https://www.python.org/), [SciPy](https://scipy.org/), [Tree](https://github.com/deepmind/tree), -[Trimesh](https://github.com/mikedh/trimesh) and -[XArray](https://github.com/pydata/xarray). +[Trimesh](https://github.com/mikedh/trimesh) +[XArray](https://github.com/pydata/xarray) and +[XArray-TensorStore](https://github.com/google/xarray-tensorstore). -### License and attribution +## License and Disclaimers -The Colab notebook and the associated code are licensed under the Apache -License, Version 2.0. You may obtain a copy of the License at: -https://www.apache.org/licenses/LICENSE-2.0. +The Colab notebooks and the associated code are licensed under the Apache License, Version 2.0. You may obtain a copy of the License at: https://www.apache.org/licenses/LICENSE-2.0. -The model weights are made available for use under the terms of the Creative -Commons Attribution-NonCommercial-ShareAlike 4.0 International -(CC BY-NC-SA 4.0). You may obtain a copy of the License at: -https://creativecommons.org/licenses/by-nc-sa/4.0/. +The model weights are made available for use under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You may obtain a copy of the License at: https://creativecommons.org/licenses/by-nc-sa/4.0/. -The weights were trained on ECMWF's ERA5 and HRES data. The colab includes a few -examples of ERA5 and HRES data that can be used as inputs to the models. -ECMWF data product are subject to the following terms: +This is not an officially supported Google product. -1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)". -2. Source www.ecmwf.int -3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/ -4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use. +Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY-NC-SA 4.0 licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses. -### Disclaimer +GenCast and GraphCast are part of an experimental research project. You are solely responsible for determining the appropriateness of using or distributing GenCast, GraphCast or any outputs generated and assume all risks associated with your use or distribution of GenCast, GraphCast and outputs and your exercise of rights and permissions granted by Google to you under the relevant License. Use discretion before relying on, publishing, downloading or otherwise using GenCast, GraphCast or any outputs generated. GenCast, GraphCast or any outputs generated (i) are not based on data published by; (ii) have not been produced in collaboration with; and (iii) have not been endorsed by any government meteorological agency or department and in no way replaces official alerts, warnings or notices published by such agencies. -This is not an officially supported Google product. +Copyright 2024 DeepMind Technologies Limited. -Copyright 2023 DeepMind Technologies Limited. -### Citation +## Citations -If you use this work, consider citing our paper ([blog post](https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/), [Science](https://www.science.org/doi/10.1126/science.adi2336), [arXiv](https://arxiv.org/abs/2212.12794)): +If you use this work, consider citing our papers ([blog post](https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/), [Science](https://www.science.org/doi/10.1126/science.adi2336), [arXiv](https://arxiv.org/abs/2212.12794), [arxiv GenCast](https://arxiv.org/abs/2312.15796)): ```latex @article{lam2023learning, @@ -131,3 +199,31 @@ If you use this work, consider citing our paper ([blog post](https://deepmind.go publisher={American Association for the Advancement of Science} } ``` + + +```latex +@article{price2023gencast, + title={GenCast: Diffusion-based ensemble forecasting for medium-range weather}, + author={Price, Ilan and Sanchez-Gonzalez, Alvaro and Alet, Ferran and Andersson, Tom R and El-Kadi, Andrew and Masters, Dominic and Ewalds, Timo and Stott, Jacklynn and Mohamed, Shakir and Battaglia, Peter and Lam, Remi and Willson, Matthew}, + journal={arXiv preprint arXiv:2312.15796}, + year={2023} +} +``` + +## Acknowledgements + +The (i) GenCast and GraphCast communicate with and/or reference with the following separate libraries and packages and the colab notebooks include a few examples of ECMWF’s ERA5 and HRES data that can be used as input to the models. +Data and products of the European Centre for Medium-range Weather Forecasts (ECMWF), as modified by Google. +Modified Copernicus Climate Change Service information 2023. Neither the European Commission nor ECMWF is responsible for any use that may be made of the Copernicus information or data it contains. +ECMWF HRES datasets +Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)". +Source: www.ecmwf.int +License Statement: ECMWF open data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/ +Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use. + +Use of the third-party materials referred to above may be governed by separate terms and conditions or license provisions. Your use of the third-party materials is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use. + + +## Contact + +For feedback and questions, contact us at gencast@google.com. diff --git a/gencast_demo_cloud_vm.ipynb b/gencast_demo_cloud_vm.ipynb new file mode 100644 index 00000000..122baa79 --- /dev/null +++ b/gencast_demo_cloud_vm.ipynb @@ -0,0 +1,690 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9KHpdDlzYNDI" + }, + "source": [ + "\u003e Copyright 2024 DeepMind Technologies Limited.\n", + "\u003e\n", + "\u003e Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\u003e you may not use this file except in compliance with the License.\n", + "\u003e You may obtain a copy of the License at\n", + "\u003e\n", + "\u003e http://www.apache.org/licenses/LICENSE-2.0\n", + "\u003e\n", + "\u003e Unless required by applicable law or agreed to in writing, software\n", + "\u003e distributed under the License is distributed on an \"AS-IS\" BASIS,\n", + "\u003e WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "\u003e See the License for the specific language governing permissions and\n", + "\u003e limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GnMHywhUhlgJ" + }, + "source": [ + "# GenCast Demo\n", + "\n", + "This notebook demonstrates running all GenCast models provided in the repository:\n", + "\n", + "1. `GenCast 0p25deg \u003c2019`\n", + "2. `GenCast 0p25deg Operational \u003c2019`\n", + "3. `GenCast 1p0deg \u003c2019`\n", + "4. `GenCast 1p0deg Mini \u003c2019`\n", + "\n", + "While `GenCast 1p0deg Mini \u003c2019` is runnable with the freely provided TPUv2-8 configuration in Colab, the other models require compute that can be accessed via Google Cloud.\n", + "\n", + "See [cloud_vm_setup.md](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md) for detailed instructions on launching a Google Cloud TPU VM and connecting to it via this notebook. This document also provides some more information on the memory requirements of the models.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMbbXFl4msJw" + }, + "source": [ + "# Installation and Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "233zaiZYqCnc" + }, + "outputs": [], + "source": [ + "# @title Pip install repo and dependencies\n", + "\n", + "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z_j8ej4Pyg1L" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import math\n", + "from typing import Optional\n", + "import haiku as hk\n", + "from IPython.display import HTML\n", + "from IPython import display\n", + "import ipywidgets as widgets\n", + "import jax\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import xarray\n", + "\n", + "from graphcast import rollout\n", + "from graphcast import xarray_jax\n", + "from graphcast import normalization\n", + "from graphcast import checkpoint\n", + "from graphcast import data_utils\n", + "from graphcast import xarray_tree\n", + "from graphcast import gencast\n", + "from graphcast import denoiser\n", + "from graphcast import nan_cleaning\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OzYgQ0QN-kn8" + }, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "\n", + "def select(\n", + " data: xarray.Dataset,\n", + " variable: str,\n", + " level: Optional[int] = None,\n", + " max_steps: Optional[int] = None\n", + " ) -\u003e xarray.Dataset:\n", + " data = data[variable]\n", + " if \"batch\" in data.dims:\n", + " data = data.isel(batch=0)\n", + " if max_steps is not None and \"time\" in data.sizes and max_steps \u003c data.sizes[\"time\"]:\n", + " data = data.isel(time=range(0, max_steps))\n", + " if level is not None and \"level\" in data.coords:\n", + " data = data.sel(level=level)\n", + " return data\n", + "\n", + "def scale(\n", + " data: xarray.Dataset,\n", + " center: Optional[float] = None,\n", + " robust: bool = False,\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + " vmin = np.nanpercentile(data, (2 if robust else 0))\n", + " vmax = np.nanpercentile(data, (98 if robust else 100))\n", + " if center is not None:\n", + " diff = max(vmax - center, center - vmin)\n", + " vmin = center - diff\n", + " vmax = center + diff\n", + " return (data, matplotlib.colors.Normalize(vmin, vmax),\n", + " (\"RdBu_r\" if center is not None else \"viridis\"))\n", + "\n", + "def plot_data(\n", + " data: dict[str, xarray.Dataset],\n", + " fig_title: str,\n", + " plot_size: float = 5,\n", + " robust: bool = False,\n", + " cols: int = 4\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + "\n", + " first_data = next(iter(data.values()))[0]\n", + " max_steps = first_data.sizes.get(\"time\", 1)\n", + " assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", + "\n", + " cols = min(cols, len(data))\n", + " rows = math.ceil(len(data) / cols)\n", + " figure = plt.figure(figsize=(plot_size * 2 * cols,\n", + " plot_size * rows))\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " figure.subplots_adjust(wspace=0, hspace=0)\n", + " figure.tight_layout()\n", + "\n", + " images = []\n", + " for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", + " ax = figure.add_subplot(rows, cols, i+1)\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " im = ax.imshow(\n", + " plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", + " origin=\"lower\", cmap=cmap)\n", + " plt.colorbar(\n", + " mappable=im,\n", + " ax=ax,\n", + " orientation=\"vertical\",\n", + " pad=0.02,\n", + " aspect=16,\n", + " shrink=0.75,\n", + " cmap=cmap,\n", + " extend=(\"both\" if robust else \"neither\"))\n", + " images.append(im)\n", + "\n", + " def update(frame):\n", + " if \"time\" in first_data.dims:\n", + " td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", + " figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", + " else:\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", + " im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", + "\n", + " ani = animation.FuncAnimation(\n", + " fig=figure, func=update, frames=max_steps, interval=250)\n", + " plt.close(figure.number)\n", + " return HTML(ani.to_jshtml())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rQWk0RRuCjDN" + }, + "source": [ + "# Load the Data and initialize the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jTRvMoexzjYm" + }, + "outputs": [], + "source": [ + "# @title Set paths\n", + "\n", + "MODEL_PATH = \"\" # E.g. \"GenCast 1p0deg _2019.npz\"\n", + "DATA_PATH = \"\" # E.g. \"source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc\"\n", + "STATS_DIR = \"\" # E.g. \"stats/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cgfYjE1YhALA" + }, + "outputs": [], + "source": [ + "# @title Load the model\n", + "\n", + "with open(MODEL_PATH, \"rb\") as f:\n", + " ckpt = checkpoint.load(f, gencast.CheckPoint)\n", + "params = ckpt.params\n", + "state = {}\n", + "\n", + "task_config = ckpt.task_config\n", + "sampler_config = ckpt.sampler_config\n", + "noise_config = ckpt.noise_config\n", + "noise_encoder_config = ckpt.noise_encoder_config\n", + "denoiser_architecture_config = ckpt.denoiser_architecture_config\n", + "print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", + "print(\"Model license:\\n\", ckpt.license, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z2AqgxUgiALy" + }, + "source": [ + "## Load the example data\n", + "\n", + "Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.\n", + "\n", + "Example HRES-fc0 datasets are available at 0.25 degree resolution.\n", + "\n", + "Some transformations were done from the base datasets:\n", + "- We accumulated precipitation over 12 hours instead of the default 1 hour.\n", + "- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).\n", + "\n", + "The data resolution must match the model that is loaded.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5XGzOww0y_BC" + }, + "outputs": [], + "source": [ + "# @title Check example dataset matches model\n", + "\n", + "def parse_file_parts(file_name):\n", + " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n", + "\n", + "def data_valid_for_model(file_name: str, params_file_name: str):\n", + " \"\"\"Check data type and resolution matches.\"\"\"\n", + " data_file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", + " res_matches = data_file_parts[\"res\"].replace(\".\", \"p\") in params_file_name.lower()\n", + " source_matches = \"Operational\" in params_file_name\n", + " if data_file_parts[\"source\"] == \"era5\":\n", + " source_matches = not source_matches\n", + " return res_matches and source_matches\n", + "\n", + "assert data_valid_for_model(DATA_PATH, MODEL_PATH)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yz-ekISoJxeZ" + }, + "outputs": [], + "source": [ + "# @title Load weather data\n", + "\n", + "with open(DATA_PATH, \"rb\") as f:\n", + " example_batch = xarray.load_dataset(f).compute()\n", + "\n", + "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", + "\n", + "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(DATA_PATH.removesuffix(\".nc\")).items()]))\n", + "\n", + "example_batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iqzXVpn9_b15" + }, + "outputs": [], + "source": [ + "# @title Plot example data\n", + "\n", + "plot_size = 7\n", + "variable = \"geopotential\"\n", + "level = 500\n", + "steps = example_batch.dims[\"time\"]\n", + "\n", + "\n", + "data = {\n", + " \" \": scale(select(example_batch, variable, level, steps), robust=True),\n", + "}\n", + "fig_title = variable\n", + "if \"level\" in example_batch[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, robust=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "njD4jsPTPKvJ" + }, + "outputs": [], + "source": [ + "# @title Extract training and eval data\n", + "\n", + "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", \"12h\"), # Only 1AR training.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", f\"{(example_batch.dims['time']-2)*12}h\"), # All but 2 input frames.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "print(\"All Examples: \", example_batch.dims.mapping)\n", + "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Train Targets: \", train_targets.dims.mapping)\n", + "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", + "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", + "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-DJzie5me2-H" + }, + "outputs": [], + "source": [ + "# @title Load normalization data\n", + "\n", + "with open(STATS_DIR +\"diffs_stddev_by_level.nc\", \"rb\") as f:\n", + " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"mean_by_level.nc\", \"rb\") as f:\n", + " mean_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"stddev_by_level.nc\", \"rb\") as f:\n", + " stddev_by_level = xarray.load_dataset(f).compute()\n", + "with open(STATS_DIR +\"min_by_level.nc\", \"rb\") as f:\n", + " min_by_level = xarray.load_dataset(f).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ke2zQyuT_sMA" + }, + "outputs": [], + "source": [ + "# @title Build jitted functions, and possibly initialize random weights\n", + "\n", + "\n", + "def construct_wrapped_gencast():\n", + " \"\"\"Constructs and wraps the GenCast Predictor.\"\"\"\n", + " predictor = gencast.GenCast(\n", + " sampler_config=sampler_config,\n", + " task_config=task_config,\n", + " denoiser_architecture_config=denoiser_architecture_config,\n", + " noise_config=noise_config,\n", + " noise_encoder_config=noise_encoder_config,\n", + " )\n", + "\n", + " predictor = normalization.InputsAndResiduals(\n", + " predictor,\n", + " diffs_stddev_by_level=diffs_stddev_by_level,\n", + " mean_by_level=mean_by_level,\n", + " stddev_by_level=stddev_by_level,\n", + " )\n", + "\n", + " predictor = nan_cleaning.NaNCleaner(\n", + " predictor=predictor,\n", + " reintroduce_nans=True,\n", + " fill_value=min_by_level,\n", + " var_to_clean='sea_surface_temperature',\n", + " )\n", + "\n", + " return predictor\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def run_forward(inputs, targets_template, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def loss_fn(inputs, targets, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", + " return xarray_tree.map_structure(\n", + " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", + " (loss, diagnostics),\n", + " )\n", + "\n", + "\n", + "def grads_fn(params, state, inputs, targets, forcings):\n", + " def _aux(params, state, i, t, f):\n", + " (loss, diagnostics), next_state = loss_fn.apply(\n", + " params, state, jax.random.PRNGKey(0), i, t, f\n", + " )\n", + " return loss, (diagnostics, next_state)\n", + "\n", + " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", + " _aux, has_aux=True\n", + " )(params, state, inputs, targets, forcings)\n", + " return loss, diagnostics, next_state, grads\n", + "\n", + "\n", + "if params is None:\n", + " init_jitted = jax.jit(loss_fn.init)\n", + " params, state = init_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings,\n", + " )\n", + "\n", + "\n", + "loss_fn_jitted = jax.jit(\n", + " lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]\n", + ")\n", + "grads_fn_jitted = jax.jit(grads_fn)\n", + "run_forward_jitted = jax.jit(\n", + " lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBNutliiCyqA" + }, + "source": [ + "# Run the model\n", + "\n", + "The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.\n", + "This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.\n", + "\n", + "Note that the cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t-6ik5tU1yr7" + }, + "outputs": [], + "source": [ + "# The number of ensemble members should be a multiple of the number of devices.\n", + "len(jax.local_devices())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7obeY9i9oTtD" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (loop in python)\n", + "\n", + "print(\"Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Targets: \", eval_targets.dims.mapping)\n", + "print(\"Forcings:\", eval_forcings.dims.mapping)\n", + "\n", + "num_ensemble_members = 8 # @param int\n", + "rng = jax.random.PRNGKey(0)\n", + "# We fold-in the ensemble member, this way the first N members should always\n", + "# match across different runs which use take the same inputs\n", + "# regardless of total ensemble size.\n", + "rngs = np.stack(\n", + " [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)\n", + "\n", + "chunks = []\n", + "for chunk in rollout.chunked_prediction_generator_multiple_runs(\n", + " xarray_jax.pmap(run_forward_jitted, dim=\"sample\"),\n", + " rngs=rngs,\n", + " inputs=eval_inputs,\n", + " targets_template=eval_targets * np.nan,\n", + " forcings=eval_forcings,\n", + " num_steps_per_chunk = 1,\n", + " num_samples = num_ensemble_members,\n", + " pmap_devices=jax.local_devices()\n", + " ):\n", + " chunks.append(chunk)\n", + "predictions = xarray.combine_by_coords(chunks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wn7dccXO5R7C" + }, + "outputs": [], + "source": [ + "# @title Plot prediction samples and diffs\n", + "\n", + "plot_size = 5\n", + "variable = \"2m_temperature\"\n", + "level = None\n", + "steps = predictions.dims[\"time\"]\n", + "\n", + "fig_title = variable\n", + "if \"level\" in predictions[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "for sample_idx in range(num_ensemble_members):\n", + " data = {\n", + " \"Targets\": scale(select(eval_targets, variable, level, steps), robust=True),\n", + " \"Predictions\": scale(select(predictions.isel(sample=sample_idx), variable, level, steps), robust=True),\n", + " \"Diff\": scale((select(eval_targets, variable, level, steps) -\n", + " select(predictions.isel(sample=sample_idx), variable, level, steps)),\n", + " robust=True, center=0),\n", + " }\n", + " display.display(plot_data(data, fig_title + f\", Sample {sample_idx}\", plot_size, robust=True))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3m9lW5fN4oL" + }, + "outputs": [], + "source": [ + "# @title Plot ensemble mean and CRPS\n", + "\n", + "def crps(targets, predictions, bias_corrected = True):\n", + " if predictions.sizes.get(\"sample\", 1) \u003c 2:\n", + " raise ValueError(\n", + " \"predictions must have dim 'sample' with size at least 2.\")\n", + " sum_dims = [\"sample\", \"sample2\"]\n", + " preds2 = predictions.rename({\"sample\": \"sample2\"})\n", + " num_samps = predictions.sizes[\"sample\"]\n", + " num_samps2 = (num_samps - 1) if bias_corrected else num_samps\n", + " mean_abs_diff = np.abs(\n", + " predictions - preds2).sum(\n", + " dim=sum_dims, skipna=False) / (num_samps * num_samps2)\n", + " mean_abs_err = np.abs(targets - predictions).sum(dim=\"sample\", skipna=False) / num_samps\n", + " return mean_abs_err - 0.5 * mean_abs_diff\n", + "\n", + "\n", + "plot_size = 5\n", + "variable = \"2m_temperature\"\n", + "level = None\n", + "steps = predictions.dims[\"time\"]\n", + "\n", + "fig_title = variable\n", + "if \"level\" in predictions[variable].coords:\n", + " fig_title += f\" at {level} hPa\"\n", + "\n", + "data = {\n", + " \"Targets\": scale(select(eval_targets, variable, level, steps), robust=True),\n", + " \"Ensemble Mean\": scale(select(predictions.mean(dim=[\"sample\"]), variable, level, steps), robust=True),\n", + " \"Ensemble CRPS\": scale(crps((select(eval_targets, variable, level, steps)),\n", + " select(predictions, variable, level, steps)),\n", + " robust=True, center=0),\n", + "}\n", + "display.display(plot_data(data, fig_title, plot_size, robust=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZLI0DhWog3Rg" + }, + "outputs": [], + "source": [ + "# @title (Optional) Save the predictions.\n", + "predictions.to_zarr(\"predictions.zarr\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6ZhRFBPD0kq" + }, + "source": [ + "# Train the model\n", + "\n", + "The following operations requires larger amounts of memory than running inference.\n", + "\n", + "The first time executing the cell takes more time, as it includes the time to jit the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nv-u3dAP7IRZ" + }, + "outputs": [], + "source": [ + "# @title Loss computation\n", + "loss, diagnostics = loss_fn_jitted(\n", + " jax.random.PRNGKey(0),\n", + " train_inputs,\n", + " train_targets,\n", + " train_forcings)\n", + "print(\"Loss:\", float(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mBNFq1IGZNLz" + }, + "outputs": [], + "source": [ + "# @title Gradient computation\n", + "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", + " params=params,\n", + " state=state,\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", + "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//gdm/weather/colab_base:weather_notebook", + "kind": "private" + }, + "name": "GenCast Cloud VM", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/gencast_mini_demo.ipynb b/gencast_mini_demo.ipynb new file mode 100644 index 00000000..4d0959ac --- /dev/null +++ b/gencast_mini_demo.ipynb @@ -0,0 +1,875 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9KHpdDlzYNDI" + }, + "source": [ + "\u003e Copyright 2024 DeepMind Technologies Limited.\n", + "\u003e\n", + "\u003e Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\u003e you may not use this file except in compliance with the License.\n", + "\u003e You may obtain a copy of the License at\n", + "\u003e\n", + "\u003e http://www.apache.org/licenses/LICENSE-2.0\n", + "\u003e\n", + "\u003e Unless required by applicable law or agreed to in writing, software\n", + "\u003e distributed under the License is distributed on an \"AS-IS\" BASIS,\n", + "\u003e WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "\u003e See the License for the specific language governing permissions and\n", + "\u003e limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GnMHywhUhlgJ" + }, + "source": [ + "# GenCast Mini Demo\n", + "\n", + "This notebook demonstrates running `GenCast 1p0deg Mini \u003c2019`.\n", + "\n", + "`GenCast 1p0deg Mini \u003c2019` is a GenCast model at 1deg resolution, with 13 pressure levels and a 4 times refined icosahedral mesh. It is trained on ERA5 data from 1979 to 2018, and can be causally evaluated on 2019 and later years.\n", + "\n", + "While other GenCast models are [available](https://github.com/google-deepmind/graphcast/blob/main/README.md), this model has the smallest memory footprint of those provided and is the only one runnable with the freely provided TPUv2-8 configuration in Colab. You can select this configuration in `Runtime\u003eChange Runtime Type`.\n", + "\n", + "**N.B.** The performance of `GenCast 1p0deg Mini \u003c2019` is reasonable but is not representative of the performance of the other GenCast models described in the [README](https://github.com/google-deepmind/graphcast/blob/main/README.md).\n", + "\n", + "To run the other models using Google Cloud Compute, refer to [gencast_demo_cloud_vm.ipynb](https://colab.research.google.com/github/deepmind/graphcast/blob/master/gencast_demo_cloud_vm.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMbbXFl4msJw" + }, + "source": [ + "# Installation and Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-gAH79SRwp9G" + }, + "outputs": [], + "source": [ + "# @title Upgrade packages (kernel needs to be restarted after running this cell).\n", + "\n", + "%pip install -U importlib_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "233zaiZYqCnc" + }, + "outputs": [], + "source": [ + "# @title Pip install repo and dependencies\n", + "\n", + "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z_j8ej4Pyg1L" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import math\n", + "from google.cloud import storage\n", + "from typing import Optional\n", + "import haiku as hk\n", + "from IPython.display import HTML\n", + "from IPython import display\n", + "import ipywidgets as widgets\n", + "import jax\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import xarray\n", + "\n", + "from graphcast import rollout\n", + "from graphcast import xarray_jax\n", + "from graphcast import normalization\n", + "from graphcast import checkpoint\n", + "from graphcast import data_utils\n", + "from graphcast import xarray_tree\n", + "from graphcast import gencast\n", + "from graphcast import denoiser\n", + "from graphcast import nan_cleaning\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "OzYgQ0QN-kn8" + }, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "\n", + "def select(\n", + " data: xarray.Dataset,\n", + " variable: str,\n", + " level: Optional[int] = None,\n", + " max_steps: Optional[int] = None\n", + " ) -\u003e xarray.Dataset:\n", + " data = data[variable]\n", + " if \"batch\" in data.dims:\n", + " data = data.isel(batch=0)\n", + " if max_steps is not None and \"time\" in data.sizes and max_steps \u003c data.sizes[\"time\"]:\n", + " data = data.isel(time=range(0, max_steps))\n", + " if level is not None and \"level\" in data.coords:\n", + " data = data.sel(level=level)\n", + " return data\n", + "\n", + "def scale(\n", + " data: xarray.Dataset,\n", + " center: Optional[float] = None,\n", + " robust: bool = False,\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + " vmin = np.nanpercentile(data, (2 if robust else 0))\n", + " vmax = np.nanpercentile(data, (98 if robust else 100))\n", + " if center is not None:\n", + " diff = max(vmax - center, center - vmin)\n", + " vmin = center - diff\n", + " vmax = center + diff\n", + " return (data, matplotlib.colors.Normalize(vmin, vmax),\n", + " (\"RdBu_r\" if center is not None else \"viridis\"))\n", + "\n", + "def plot_data(\n", + " data: dict[str, xarray.Dataset],\n", + " fig_title: str,\n", + " plot_size: float = 5,\n", + " robust: bool = False,\n", + " cols: int = 4\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + "\n", + " first_data = next(iter(data.values()))[0]\n", + " max_steps = first_data.sizes.get(\"time\", 1)\n", + " assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", + "\n", + " cols = min(cols, len(data))\n", + " rows = math.ceil(len(data) / cols)\n", + " figure = plt.figure(figsize=(plot_size * 2 * cols,\n", + " plot_size * rows))\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " figure.subplots_adjust(wspace=0, hspace=0)\n", + " figure.tight_layout()\n", + "\n", + " images = []\n", + " for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", + " ax = figure.add_subplot(rows, cols, i+1)\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " im = ax.imshow(\n", + " plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", + " origin=\"lower\", cmap=cmap)\n", + " plt.colorbar(\n", + " mappable=im,\n", + " ax=ax,\n", + " orientation=\"vertical\",\n", + " pad=0.02,\n", + " aspect=16,\n", + " shrink=0.75,\n", + " cmap=cmap,\n", + " extend=(\"both\" if robust else \"neither\"))\n", + " images.append(im)\n", + "\n", + " def update(frame):\n", + " if \"time\" in first_data.dims:\n", + " td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", + " figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", + " else:\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", + " im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", + "\n", + " ani = animation.FuncAnimation(\n", + " fig=figure, func=update, frames=max_steps, interval=250)\n", + " plt.close(figure.number)\n", + " return HTML(ani.to_jshtml())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rQWk0RRuCjDN" + }, + "source": [ + "# Load the Data and initialize the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ttMHeiGCppjB" + }, + "outputs": [], + "source": [ + "# @title Authenticate with Google Cloud Storage\n", + "\n", + "# Gives you an authenticated client, in case you want to use a private bucket.\n", + "gcs_client = storage.Client.create_anonymous_client()\n", + "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")\n", + "dir_prefix = \"gencast/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ty5WSDRjhDBF" + }, + "source": [ + "## Load the model params\n", + "\n", + "Choose one of the two ways of getting model params:\n", + "- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.\n", + "- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PMoFXuZXs-xg" + }, + "outputs": [], + "source": [ + "# @title Choose the model\n", + "\n", + "params_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix+\"params/\"))\n", + " if (name := blob.name.removeprefix(dir_prefix+\"params/\"))] # Drop empty string.\n", + "\n", + "latent_value_options = [int(2**i) for i in range(4, 10)]\n", + "random_latent_size = widgets.Dropdown(\n", + " options=latent_value_options, value=512,description=\"Latent size:\")\n", + "random_attention_type = widgets.Dropdown(\n", + " options=[\"splash_mha\", \"triblockdiag_mha\", \"mha\"], value=\"splash_mha\", description=\"Attention:\")\n", + "random_mesh_size = widgets.IntSlider(\n", + " value=4, min=4, max=6, description=\"Mesh size:\")\n", + "random_num_heads = widgets.Dropdown(\n", + " options=[int(2**i) for i in range(0, 3)], value=4,description=\"Num heads:\")\n", + "random_attention_k_hop = widgets.Dropdown(\n", + " options=[int(2**i) for i in range(2, 5)], value=16,description=\"Attn k hop:\")\n", + "\n", + "def update_latent_options(*args):\n", + " def _latent_valid_for_attn(attn, latent, heads):\n", + " head_dim, rem = divmod(latent, heads)\n", + " if rem != 0:\n", + " return False\n", + " # Required for splash attn.\n", + " if head_dim % 128 != 0:\n", + " return attn != \"splash_mha\"\n", + " return True\n", + " attn = random_attention_type.value\n", + " heads = random_num_heads.value\n", + " random_latent_size.options = [\n", + " latent for latent in latent_value_options\n", + " if _latent_valid_for_attn(attn, latent, heads)]\n", + "\n", + "# Observe changes to only allow for valid combinations.\n", + "random_attention_type.observe(update_latent_options, \"value\")\n", + "random_latent_size.observe(update_latent_options, \"value\")\n", + "random_num_heads.observe(update_latent_options, \"value\")\n", + "\n", + "params_file = widgets.Dropdown(\n", + " options=[f for f in params_file_options if \"Mini\" in f],\n", + " description=\"Params file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "\n", + "source_tab = widgets.Tab([\n", + " widgets.VBox([\n", + " random_attention_type,\n", + " random_mesh_size,\n", + " random_num_heads,\n", + " random_latent_size,\n", + " random_attention_k_hop\n", + " ]),\n", + " params_file,\n", + "])\n", + "source_tab.set_title(0, \"Random\")\n", + "source_tab.set_title(1, \"Checkpoint\")\n", + "widgets.VBox([\n", + " source_tab,\n", + " widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "cgfYjE1YhALA" + }, + "outputs": [], + "source": [ + "# @title Load the model\n", + "\n", + "source = source_tab.get_title(source_tab.selected_index)\n", + "\n", + "if source == \"Random\":\n", + " params = None # Filled in below\n", + " state = {}\n", + " task_config = gencast.TASK\n", + " # Use default values.\n", + " sampler_config = gencast.SamplerConfig()\n", + " noise_config = gencast.NoiseConfig()\n", + " noise_encoder_config = denoiser.NoiseEncoderConfig()\n", + " # Configure, otherwise use default values.\n", + " denoiser_architecture_config = denoiser.DenoiserArchitectureConfig(\n", + " sparse_transformer_config = denoiser.SparseTransformerConfig(\n", + " attention_k_hop=random_attention_k_hop.value,\n", + " attention_type=random_attention_type.value,\n", + " d_model=random_latent_size.value,\n", + " num_heads=random_num_heads.value\n", + " ),\n", + " mesh_size=random_mesh_size.value,\n", + " latent_size=random_latent_size.value,\n", + " )\n", + "else:\n", + " assert source == \"Checkpoint\"\n", + " with gcs_bucket.blob(dir_prefix + f\"params/{params_file.value}\").open(\"rb\") as f:\n", + " ckpt = checkpoint.load(f, gencast.CheckPoint)\n", + " params = ckpt.params\n", + " state = {}\n", + "\n", + " task_config = ckpt.task_config\n", + " sampler_config = ckpt.sampler_config\n", + " noise_config = ckpt.noise_config\n", + " noise_encoder_config = ckpt.noise_encoder_config\n", + " denoiser_architecture_config = ckpt.denoiser_architecture_config\n", + " print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", + " print(\"Model license:\\n\", ckpt.license, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z2AqgxUgiALy" + }, + "source": [ + "## Load the example data\n", + "\n", + "Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.\n", + "\n", + "Example HRES-fc0 datasets are available at 0.25 degree resolution.\n", + "\n", + "Some transformations were done from the base datasets:\n", + "- We accumulated precipitation over 12 hours instead of the default 1 hour.\n", + "- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).\n", + "\n", + "The data resolution must match the model that is loaded. Since we are running GenCast Mini, this will be 1 degree.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5XGzOww0y_BC" + }, + "outputs": [], + "source": [ + "# @title Get and filter the list of available example datasets\n", + "\n", + "dataset_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix + \"dataset/\"))\n", + " if (name := blob.name.removeprefix(dir_prefix+\"dataset/\"))] # Drop empty string.\n", + "\n", + "def parse_file_parts(file_name):\n", + " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n", + "\n", + "\n", + "def data_valid_for_model(file_name: str, params_file_name: str):\n", + " \"\"\"Check data type and resolution matches.\"\"\"\n", + " if source == \"Random\":\n", + " return True\n", + " data_file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", + " res_matches = data_file_parts[\"res\"].replace(\".\", \"p\") in params_file_name.lower()\n", + " source_matches = \"Operational\" in params_file_name\n", + " if data_file_parts[\"source\"] == \"era5\":\n", + " source_matches = not source_matches\n", + " return res_matches and source_matches\n", + "\n", + "dataset_file = widgets.Dropdown(\n", + " options=[\n", + " (\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(option.removesuffix(\".nc\")).items()]), option)\n", + " for option in dataset_file_options\n", + " if data_valid_for_model(option, params_file.value)\n", + " ],\n", + " description=\"Dataset file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "widgets.VBox([\n", + " dataset_file,\n", + " widgets.Label(value=\"Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Yz-ekISoJxeZ" + }, + "outputs": [], + "source": [ + "# @title Load weather data\n", + "\n", + "with gcs_bucket.blob(dir_prefix+f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", + " example_batch = xarray.load_dataset(f).compute()\n", + "\n", + "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", + "\n", + "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(dataset_file.value.removesuffix(\".nc\")).items()]))\n", + "\n", + "example_batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lXjFvdE6qStr" + }, + "outputs": [], + "source": [ + "# @title Choose data to plot\n", + "\n", + "plot_example_variable = widgets.Dropdown(\n", + " options=example_batch.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_example_level = widgets.Dropdown(\n", + " options=example_batch.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_example_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_example_max_steps = widgets.IntSlider(\n", + " min=1, max=example_batch.dims[\"time\"], value=example_batch.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "\n", + "widgets.VBox([\n", + " plot_example_variable,\n", + " plot_example_level,\n", + " plot_example_robust,\n", + " plot_example_max_steps,\n", + " widgets.Label(value=\"Run the next cell to plot the data. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "iqzXVpn9_b15" + }, + "outputs": [], + "source": [ + "# @title Plot example data\n", + "\n", + "plot_size = 7\n", + "\n", + "data = {\n", + " \" \": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),\n", + " robust=plot_example_robust.value),\n", + "}\n", + "fig_title = plot_example_variable.value\n", + "if \"level\" in example_batch[plot_example_variable.value].coords:\n", + " fig_title += f\" at {plot_example_level.value} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, plot_example_robust.value)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "njD4jsPTPKvJ" + }, + "outputs": [], + "source": [ + "# @title Extract training and eval data\n", + "\n", + "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", \"12h\"), # Only 1AR training.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"12h\", f\"{(example_batch.dims['time']-2)*12}h\"), # All but 2 input frames.\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "print(\"All Examples: \", example_batch.dims.mapping)\n", + "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Train Targets: \", train_targets.dims.mapping)\n", + "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", + "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", + "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-DJzie5me2-H" + }, + "outputs": [], + "source": [ + "# @title Load normalization data\n", + "\n", + "with gcs_bucket.blob(dir_prefix+\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", + " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", + " mean_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", + " stddev_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(dir_prefix+\"stats/min_by_level.nc\").open(\"rb\") as f:\n", + " min_by_level = xarray.load_dataset(f).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ke2zQyuT_sMA" + }, + "outputs": [], + "source": [ + "# @title Build jitted functions, and possibly initialize random weights\n", + "\n", + "\n", + "def construct_wrapped_gencast():\n", + " \"\"\"Constructs and wraps the GenCast Predictor.\"\"\"\n", + " predictor = gencast.GenCast(\n", + " sampler_config=sampler_config,\n", + " task_config=task_config,\n", + " denoiser_architecture_config=denoiser_architecture_config,\n", + " noise_config=noise_config,\n", + " noise_encoder_config=noise_encoder_config,\n", + " )\n", + "\n", + " predictor = normalization.InputsAndResiduals(\n", + " predictor,\n", + " diffs_stddev_by_level=diffs_stddev_by_level,\n", + " mean_by_level=mean_by_level,\n", + " stddev_by_level=stddev_by_level,\n", + " )\n", + "\n", + " predictor = nan_cleaning.NaNCleaner(\n", + " predictor=predictor,\n", + " reintroduce_nans=True,\n", + " fill_value=min_by_level,\n", + " var_to_clean='sea_surface_temperature',\n", + " )\n", + "\n", + " return predictor\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def run_forward(inputs, targets_template, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def loss_fn(inputs, targets, forcings):\n", + " predictor = construct_wrapped_gencast()\n", + " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", + " return xarray_tree.map_structure(\n", + " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", + " (loss, diagnostics),\n", + " )\n", + "\n", + "\n", + "def grads_fn(params, state, inputs, targets, forcings):\n", + " def _aux(params, state, i, t, f):\n", + " (loss, diagnostics), next_state = loss_fn.apply(\n", + " params, state, jax.random.PRNGKey(0), i, t, f\n", + " )\n", + " return loss, (diagnostics, next_state)\n", + "\n", + " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", + " _aux, has_aux=True\n", + " )(params, state, inputs, targets, forcings)\n", + " return loss, diagnostics, next_state, grads\n", + "\n", + "\n", + "if params is None:\n", + " init_jitted = jax.jit(loss_fn.init)\n", + " params, state = init_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings,\n", + " )\n", + "\n", + "\n", + "loss_fn_jitted = jax.jit(\n", + " lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]\n", + ")\n", + "grads_fn_jitted = jax.jit(grads_fn)\n", + "run_forward_jitted = jax.jit(\n", + " lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBNutliiCyqA" + }, + "source": [ + "# Run the model\n", + "\n", + "The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.\n", + "This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.\n", + "\n", + "Note that the cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7obeY9i9oTtD" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (loop in python)\n", + "\n", + "print(\"Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Targets: \", eval_targets.dims.mapping)\n", + "print(\"Forcings:\", eval_forcings.dims.mapping)\n", + "\n", + "num_ensemble_members = 8 # @param int\n", + "rng = jax.random.PRNGKey(0)\n", + "# We fold-in the ensemble member, this way the first N members should always\n", + "# match across different runs which use take the same inputs, regardless of\n", + "# total ensemble size.\n", + "rngs = np.stack(\n", + " [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)\n", + "\n", + "chunks = []\n", + "for chunk in rollout.chunked_prediction_generator_multiple_runs(\n", + " xarray_jax.pmap(run_forward_jitted, dim=\"sample\"),\n", + " rngs=rngs,\n", + " inputs=eval_inputs,\n", + " targets_template=eval_targets * np.nan,\n", + " forcings=eval_forcings,\n", + " num_steps_per_chunk = 1,\n", + " num_samples = num_ensemble_members,\n", + " pmap_devices=jax.local_devices()\n", + " ):\n", + " chunks.append(chunk)\n", + "predictions = xarray.combine_by_coords(chunks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "nIoUfBxBAwqm" + }, + "outputs": [], + "source": [ + "# @title Choose predictions to plot\n", + "\n", + "plot_pred_variable = widgets.Dropdown(\n", + " options=predictions.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_pred_level = widgets.Dropdown(\n", + " options=predictions.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_pred_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_pred_max_steps = widgets.IntSlider(\n", + " min=1,\n", + " max=predictions.dims[\"time\"],\n", + " value=predictions.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "plot_pred_samples = widgets.IntSlider(\n", + " min=1,\n", + " max=num_ensemble_members,\n", + " value=num_ensemble_members,\n", + " description=\"Samples\")\n", + "\n", + "widgets.VBox([\n", + " plot_pred_variable,\n", + " plot_pred_level,\n", + " plot_pred_robust,\n", + " plot_pred_max_steps,\n", + " plot_pred_samples,\n", + " widgets.Label(value=\"Run the next cell to plot the predictions. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "wn7dccXO5R7C" + }, + "outputs": [], + "source": [ + "# @title Plot prediction samples and diffs\n", + "\n", + "plot_size = 5\n", + "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", + "\n", + "fig_title = plot_pred_variable.value\n", + "if \"level\" in predictions[plot_pred_variable.value].coords:\n", + " fig_title += f\" at {plot_pred_level.value} hPa\"\n", + "\n", + "for sample_idx in range(plot_pred_samples.value):\n", + " data = {\n", + " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Predictions\": scale(select(predictions.isel(sample=sample_idx), plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Diff\": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -\n", + " select(predictions.isel(sample=sample_idx), plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " robust=plot_pred_robust.value, center=0),\n", + " }\n", + " display.display(plot_data(data, fig_title + f\", Sample {sample_idx}\", plot_size, plot_pred_robust.value))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3m9lW5fN4oL" + }, + "outputs": [], + "source": [ + "# @title Plot ensemble mean and CRPS\n", + "\n", + "def crps(targets, predictions, bias_corrected = True):\n", + " if predictions.sizes.get(\"sample\", 1) \u003c 2:\n", + " raise ValueError(\n", + " \"predictions must have dim 'sample' with size at least 2.\")\n", + " sum_dims = [\"sample\", \"sample2\"]\n", + " preds2 = predictions.rename({\"sample\": \"sample2\"})\n", + " num_samps = predictions.sizes[\"sample\"]\n", + " num_samps2 = (num_samps - 1) if bias_corrected else num_samps\n", + " mean_abs_diff = np.abs(\n", + " predictions - preds2).sum(\n", + " dim=sum_dims, skipna=False) / (num_samps * num_samps2)\n", + " mean_abs_err = np.abs(targets - predictions).sum(dim=\"sample\", skipna=False) / num_samps\n", + " return mean_abs_err - 0.5 * mean_abs_diff\n", + "\n", + "\n", + "plot_size = 5\n", + "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", + "\n", + "fig_title = plot_pred_variable.value\n", + "if \"level\" in predictions[plot_pred_variable.value].coords:\n", + " fig_title += f\" at {plot_pred_level.value} hPa\"\n", + "\n", + "data = {\n", + " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Ensemble Mean\": scale(select(predictions.mean(dim=[\"sample\"]), plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Ensemble CRPS\": scale(crps((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " robust=plot_pred_robust.value, center=0),\n", + "}\n", + "display.display(plot_data(data, fig_title, plot_size, plot_pred_robust.value))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6ZhRFBPD0kq" + }, + "source": [ + "# Train the model\n", + "\n", + "The following operations requires larger amounts of memory than running inference.\n", + "\n", + "The first time executing the cell takes more time, as it includes the time to jit the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nv-u3dAP7IRZ" + }, + "outputs": [], + "source": [ + "# @title Loss computation\n", + "loss, diagnostics = loss_fn_jitted(\n", + " jax.random.PRNGKey(0),\n", + " train_inputs,\n", + " train_targets,\n", + " train_forcings)\n", + "print(\"Loss:\", float(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mBNFq1IGZNLz" + }, + "outputs": [], + "source": [ + "# @title Gradient computation\n", + "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", + " params=params,\n", + " state=state,\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", + "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "z2AqgxUgiALy" + ], + "name": "GenCast Mini Demo", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/graphcast/autoregressive.py b/graphcast/autoregressive.py index 1cf13247..624c8a9d 100644 --- a/graphcast/autoregressive.py +++ b/graphcast/autoregressive.py @@ -246,7 +246,7 @@ def add_noise(x): return x + self._noise_level * jax.random.normal( hk.next_rng_key(), shape=x.shape) # Add noise to time-dependent variables of the inputs. - inputs = jax.tree_map(add_noise, inputs) + inputs = jax.tree.map(add_noise, inputs) # The per-timestep targets passed by scan to one_step_loss below will have # no leading time axis. We need a treedef without the time axis to use diff --git a/graphcast/casting.py b/graphcast/casting.py index cb0fcdc2..9678f91c 100644 --- a/graphcast/casting.py +++ b/graphcast/casting.py @@ -140,7 +140,7 @@ def _all_inputs_to_bfloat16( xarray.Dataset, xarray.Dataset]: return (inputs.astype(jnp.bfloat16), - jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets), + jax.tree.map(lambda x: x.astype(jnp.bfloat16), targets), forcings.astype(jnp.bfloat16)) @@ -149,7 +149,7 @@ def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype, def cast_fn(x): if x.dtype == input_dtype: return x.astype(output_dtype) - return jax.tree_map(cast_fn, inputs) + return jax.tree.map(cast_fn, inputs) @contextlib.contextmanager diff --git a/graphcast/deep_typed_graph_net.py b/graphcast/deep_typed_graph_net.py index 93a1bd38..27a7a670 100644 --- a/graphcast/deep_typed_graph_net.py +++ b/graphcast/deep_typed_graph_net.py @@ -34,8 +34,11 @@ } """ -from typing import Mapping, Optional +import functools +from typing import Callable, List, Mapping, Optional, Tuple +import chex +from graphcast import mlp as mlp_builder from graphcast import typed_graph from graphcast import typed_graph_net import haiku as hk @@ -44,6 +47,9 @@ import jraph +GraphToGraphNetwork = Callable[[typed_graph.TypedGraph], typed_graph.TypedGraph] + + class DeepTypedGraphNet(hk.Module): """Deep Graph Neural Network. @@ -87,6 +93,7 @@ def __init__(self, edge_output_size: Optional[Mapping[str, int]] = None, include_sent_messages_in_node_update: bool = False, use_layer_norm: bool = True, + use_norm_conditioning: bool = False, activation: str = "relu", f32_aggregation: bool = False, aggregate_edges_for_nodes_fn: str = "segment_sum", @@ -114,6 +121,18 @@ def __init__(self, include_sent_messages_in_node_update: Whether to include pooled sent messages from each node in the node update. use_layer_norm: Whether it uses layer norm or not. + use_norm_conditioning: If True, the latent feaures outputted by the + activation normalization that follows the MLPs (e.g. LayerNorm), rather + than being scaled/offset by learned parameters of the normalization + module, will be scaled/offset by offsets/biases produced by a linear + layer (with different weights for each MLP), which takes an extra + argument "global_norm_conditioning". This argument is used to condition + the normalization of all nodes and all edges (hence global), and would + usually only have a batch and feature axis. This is typically used to + condition diffusion models on the "diffusion time". Will raise an error + if this is set to True but the "global_norm_conditioning" is not passed + to the __call__ method, as well as if this is set to False, but + "global_norm_conditioning" is passed to the call method. activation: name of activation function. f32_aggregation: Use float32 in the edge aggregation. aggregate_edges_for_nodes_fn: function used to aggregate messages to each @@ -141,9 +160,14 @@ def __init__(self, self._edge_output_size = edge_output_size self._include_sent_messages_in_node_update = ( include_sent_messages_in_node_update) + if use_norm_conditioning and not use_layer_norm: + raise ValueError( + "`norm_conditioning` can only be used when " + "`use_layer_norm` is true." + ) self._use_layer_norm = use_layer_norm + self._use_norm_conditioning = use_norm_conditioning self._activation = _get_activation_fn(activation) - self._initialized = False self._f32_aggregation = f32_aggregation self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn( aggregate_edges_for_nodes_fn) @@ -154,24 +178,31 @@ def __init__(self, assert aggregate_edges_for_nodes_fn == "segment_sum" def __call__(self, - input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + input_graph: typed_graph.TypedGraph, + global_norm_conditioning: Optional[chex.Array] = None + ) -> typed_graph.TypedGraph: """Forward pass of the learnable dynamics model.""" - self._networks_builder(input_graph) + embedder_network, processor_networks, decoder_network = ( + self._networks_builder(input_graph, global_norm_conditioning) + ) # Embed input features (if applicable). - latent_graph_0 = self._embed(input_graph) + latent_graph_0 = self._embed(input_graph, embedder_network) # Do `m` message passing steps in the latent graphs. - latent_graph_m = self._process(latent_graph_0) + latent_graph_m = self._process(latent_graph_0, processor_networks) # Compute outputs from the last latent graph (if applicable). - return self._output(latent_graph_m) - - def _networks_builder(self, graph_template): - if self._initialized: - return - self._initialized = True - + return self._output(latent_graph_m, decoder_network) + + def _networks_builder( + self, + graph_template: typed_graph.TypedGraph, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> Tuple[ + GraphToGraphNetwork, List[GraphToGraphNetwork], GraphToGraphNetwork + ]: + # TODO(aelkadi): move to mlp_builder. def build_mlp(name, output_size): mlp = hk.nets.MLP( output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [ @@ -180,11 +211,40 @@ def build_mlp(name, output_size): def build_mlp_with_maybe_layer_norm(name, output_size): network = build_mlp(name, output_size) + stages = [network] + if self._use_norm_conditioning: + if global_norm_conditioning is None: + raise ValueError( + "When using norm conditioning, `global_norm_conditioning` must" + "be passed to the call method.") + # If using norm conditioning, it is no longer the responsibility of the + # LayerNorm module itself to learn its scale and offset. These will be + # learned for the module by the norm conditioning layer instead. + create_scale = create_offset = False + else: + if global_norm_conditioning is not None: + raise ValueError( + "`globa_norm_conditioning` was passed, but `norm_conditioning`" + " is not enabled.") + create_scale = create_offset = True + if self._use_layer_norm: layer_norm = hk.LayerNorm( - axis=-1, create_scale=True, create_offset=True, + axis=-1, create_scale=create_scale, create_offset=create_offset, name=name + "_layer_norm") - network = hk.Sequential([network, layer_norm]) + stages.append(layer_norm) + + if self._use_norm_conditioning: + norm_conditioning_layer = mlp_builder.LinearNormConditioning( + name=name + "_norm_conditioning") + norm_conditioning_layer = functools.partial( + norm_conditioning_layer, + # Broadcast to the node/edge axis. + norm_conditioning=global_norm_conditioning[None], + ) + stages.append(norm_conditioning_layer) + + network = hk.Sequential(stages) return jraph.concatenated_args(network) # The embedder graph network independently embeds edge and node features. @@ -208,7 +268,7 @@ def build_mlp_with_maybe_layer_norm(name, output_size): embed_edge_fn=embed_edge_fn, embed_node_fn=embed_node_fn, ) - self._embedder_network = typed_graph_net.GraphMapFeatures( + embedder_network = typed_graph_net.GraphMapFeatures( **embedder_kwargs) if self._f32_aggregation: @@ -232,9 +292,9 @@ def aggregate_fn(data, *args, **kwargs): # that update the node and edge latent features. # Note that we can use `modules.InteractionNetwork` because # it also outputs the messages as updated edge latent features. - self._processor_networks = [] + processor_networks = [] for step_i in range(self._num_message_passing_steps): - self._processor_networks.append( + processor_networks.append( typed_graph_net.InteractionNetwork( update_edge_fn=_build_update_fns_for_edge_types( build_mlp_with_maybe_layer_norm, @@ -259,11 +319,15 @@ def aggregate_fn(data, *args, **kwargs): embed_node_fn=_build_update_fns_for_node_types( build_mlp, graph_template, "decoder_nodes_", self._node_output_size) if self._node_output_size else None,) - self._output_network = typed_graph_net.GraphMapFeatures( + output_network = typed_graph_net.GraphMapFeatures( **output_kwargs) + return embedder_network, processor_networks, output_network def _embed( - self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + self, + input_graph: typed_graph.TypedGraph, + embedder_network: GraphToGraphNetwork, + ) -> typed_graph.TypedGraph: """Embeds the input graph features into a latent graph.""" # Copy the context to all of the node types, if applicable. @@ -286,11 +350,14 @@ def _embed( context=input_graph.context._replace(features=())) # Embeds the node and edge features. - latent_graph_0 = self._embedder_network(input_graph) + latent_graph_0 = embedder_network(input_graph) return latent_graph_0 def _process( - self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + self, + latent_graph_0: typed_graph.TypedGraph, + processor_networks: List[GraphToGraphNetwork], + ) -> typed_graph.TypedGraph: """Processes the latent graph with several steps of message passing.""" # Do `num_message_passing_steps` with each of the `self._processor_networks` @@ -298,7 +365,7 @@ def _process( # times. latent_graph = latent_graph_0 for unused_repetition_i in range(self._num_processor_repetitions): - for processor_network in self._processor_networks: + for processor_network in processor_networks: latent_graph = self._process_step(processor_network, latent_graph) return latent_graph @@ -326,10 +393,13 @@ def _process_step( nodes=nodes_with_residuals, edges=edges_with_residuals) return latent_graph_k - def _output(self, - latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + def _output( + self, + latent_graph: typed_graph.TypedGraph, + output_network: GraphToGraphNetwork, + ) -> typed_graph.TypedGraph: """Produces the output from the latent graph.""" - return self._output_network(latent_graph) + return output_network(latent_graph) def _build_update_fns_for_node_types( diff --git a/graphcast/denoiser.py b/graphcast/denoiser.py new file mode 100644 index 00000000..1091964c --- /dev/null +++ b/graphcast/denoiser.py @@ -0,0 +1,851 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Support for wrapping a general Predictor to act as a Denoiser.""" + +import dataclasses +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple + +import chex +from graphcast import deep_typed_graph_net +from graphcast import denoisers_base as base +from graphcast import grid_mesh_connectivity +from graphcast import icosahedral_mesh +from graphcast import model_utils +from graphcast import sparse_transformer +from graphcast import transformer +from graphcast import typed_graph +from graphcast import xarray_jax +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from scipy import sparse +import xarray + + +Kwargs = Mapping[str, Any] +NoiseLevelEncoder = Callable[[jnp.ndarray], jnp.ndarray] + + +class FourierFeaturesMLP(hk.Module): + """A simple MLP applied to Fourier features of values or their logarithms.""" + + def __init__(self, + base_period: float, + num_frequencies: int, + output_sizes: Sequence[int], + apply_log_first: bool = False, + w_init: ... = None, + activation: ... = jax.nn.gelu, + **mlp_kwargs + ): + """Initializes the module. + + Args: + base_period: + See model_utils.fourier_features. Note this would apply to log inputs if + apply_log_first is used. + num_frequencies: + See model_utils.fourier_features. + output_sizes: + Layer sizes for the MLP. + apply_log_first: + Whether to take the log of the inputs before computing Fourier features. + w_init: + Weights initializer for the MLP, default setting aims to produce + approx unit-variance outputs given the input sin/cos features. + activation: + **mlp_kwargs: + Further settings for the MLP. + """ + super().__init__() + self._base_period = base_period + self._num_frequencies = num_frequencies + self._apply_log_first = apply_log_first + if w_init is None: + # Scale of 2 is appropriate for input layer as sin/cos fourier features + # have variance 0.5 for random inputs. Also reasonable to use for later + # layers as relu activation cuts variance in half for inputs to later + # layers and gelu something close enough too. + w_init = hk.initializers.VarianceScaling( + 2.0, mode="fan_in", distribution="uniform" + ) + self._mlp = hk.nets.MLP( + output_sizes=output_sizes, + w_init=w_init, + activation=activation, + **mlp_kwargs) + + def __call__(self, values: jnp.ndarray) -> jnp.ndarray: + if self._apply_log_first: + values = jnp.log(values) + + features = model_utils.fourier_features( + values, self._base_period, self._num_frequencies) + + return self._mlp(features) + + +@chex.dataclass(frozen=True, eq=True) +class NoiseEncoderConfig: + """Configures the noise level encoding. + + Properties: + apply_log_first: Whether to take the log of the inputs before computing + Fourier features. + base_period: The base period to use. This should be greater or equal to the + range of the values, or to the period if the values have periodic + semantics (e.g. 2pi if they represent angles). Frequencies used will be + integer multiples of 1/base_period. + num_frequencies: The number of frequencies to use, we will use integer + multiples of 1/base_period from 1 up to num_frequencies inclusive. (We + don't include a zero frequency as this would just give constant features + which are redundant if a bias term is present). + output_sizes: Layer sizes for the MLP. + """ + apply_log_first: bool = True + base_period: float = 16.0 + num_frequencies: int = 32 + # 2-layer MLP applied to Fourier features + output_sizes: tuple[int, int] = (32, 16) + + +@chex.dataclass(frozen=True, eq=True) +class SparseTransformerConfig: + """Sparse Transformer config.""" + # Neighbours to attend to. + attention_k_hop: int + # Primary width, the number of channels on the carrier path. + d_model: int + # Depth, or num transformer blocks. One 'layer' is attn + ffw. + num_layers: int = 16 + # Number of heads for self-attention. + num_heads: int = 4 + # Attention type. + attention_type: str = "splash_mha" + # mask type if splash attention being used. + mask_type: str = "lazy" + block_q: int = 1024 + block_kv: int = 512 + block_kv_compute: int = 256 + block_q_dkv: int = 512 + block_kv_dkv: int = 1024 + block_kv_dkv_compute: int = 1024 + # Init scale for final ffw layer (divided by depth) + ffw_winit_final_mult: float = 0.0 + # Init scale for mha w (divided by depth). + attn_winit_final_mult: float = 0.0 + # Number of hidden units in the MLP blocks. Defaults to 4 * d_model. + ffw_hidden: int = 2048 + # Name for haiku module. + name: Optional[str] = None + + +@chex.dataclass(eq=True) +class DenoiserArchitectureConfig: + """Defines the GenCast architecture. + + Properties: + sparse_transformer_config: Config for the mesh transformer. + mesh_size: How many refinements to do on the multi-mesh. + latent_size: How many latent features to include in the various MLPs. + hidden_layers: How many hidden layers for each MLP. + radius_query_fraction_edge_length: Scalar that will be multiplied by the + length of the longest edge of the finest mesh to define the radius of + connectivity to use in the Grid2Mesh graph. Reasonable values are + between 0.6 and 1. 0.6 reduces the number of grid points feeding into + multiple mesh nodes and therefore reduces edge count and memory use, but + 1 gives better predictions. + norm_conditioning_features: List of feature names which will be used to + condition the GNN via norm_conditioning, rather than as regular + features. If this is provided, the GNN has to support the + `global_norm_conditioning` argument. For now it only supports global + norm conditioning (e.g. the same vector conditions all edges and nodes + normalization), which means features passed here must not have "lat" or + "lon" axes. In the future it may support node level norm conditioning + too. + grid2mesh_aggregate_normalization: Optional constant to normalize the output + of aggregate_edges_for_nodes_fn in the mesh2grid GNN. This can be used to + reduce the shock the model undergoes when switching resolution, which + increases the number of edges connected to a node. + node_output_size: Size of the output node representations for + each node type. For node types not specified here, the latent node + representation from the output of the processor will be returned. + """ + + sparse_transformer_config: SparseTransformerConfig + mesh_size: int + latent_size: int = 512 + hidden_layers: int = 1 + radius_query_fraction_edge_length: float = 0.6 + norm_conditioning_features: tuple[str, ...] = ("noise_level_encodings",) + grid2mesh_aggregate_normalization: Optional[float] = None + node_output_size: Optional[int] = None + + +class Denoiser(base.Denoiser): + """Wraps a general deterministic Predictor to act as a Denoiser. + + This passes an encoding of the noise level as an additional input to the + Predictor as an additional input 'noise_level_encodings' with shape + ('batch', 'noise_level_encoding_channels'). It passes the noisy_targets as + additional forcings (since they are also per-target-timestep data that the + predictor needs to condition on) with the same names as the original target + variables. + """ + + def __init__( + self, + noise_encoder_config: Optional[NoiseEncoderConfig], + denoiser_architecture_config: DenoiserArchitectureConfig, + ): + self._predictor = _DenoiserArchitecture( + denoiser_architecture_config=denoiser_architecture_config, + ) + # Use default values if not specified. + if noise_encoder_config is None: + noise_encoder_config = NoiseEncoderConfig() + self._noise_level_encoder = FourierFeaturesMLP(**noise_encoder_config) + + def __call__( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + if forcings is None: forcings = xarray.Dataset() + forcings = forcings.assign(noisy_targets) + + if noise_levels.dims != ("batch",): + raise ValueError("noise_levels expected to be shape (batch,).") + noise_level_encodings = self._noise_level_encoder( + xarray_jax.unwrap_data(noise_levels) + ) + noise_level_encodings = xarray_jax.Variable( + ("batch", "noise_level_encoding_channels"), noise_level_encodings + ) + inputs = inputs.assign(noise_level_encodings=noise_level_encodings) + + return self._predictor( + inputs=inputs, + targets_template=noisy_targets, + forcings=forcings, + **kwargs) + + +class _DenoiserArchitecture: + """GenCast Predictor. + + The model works on graphs that take into account: + * Mesh nodes: nodes for the vertices of the mesh. + * Grid nodes: nodes for the points of the grid. + * Nodes: When referring to just "nodes", this means the joint set of + both mesh nodes, concatenated with grid nodes. + + The model works with 3 graphs: + * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from grid nodes to mesh nodes using a + fixed radius query. The grid2mesh_gnn will operate in this graph. The output + of this stage will be a latent representation for the mesh nodes, and a + latent representation for the grid nodes. + * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will + operate in this graph. It will update the latent state of the mesh nodes + only. + * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from mesh nodes to grid nodes such that each grid + node is connected to 3 nodes of the mesh triangular face that contains + the grid points. The mesh2grid_gnn will operate in this graph. It will + process the updated latent state of the mesh nodes, and the latent state + of the grid nodes, to produce the final output for the grid nodes. + + The model is built on top of `TypedGraph`s so the different types of nodes and + edges can be stored and treated separately. + """ + + def __init__( + self, + denoiser_architecture_config: DenoiserArchitectureConfig, + ): + """Initializes the predictor.""" + self._spatial_features_kwargs = dict( + add_node_positions=False, + add_node_latitude=True, + add_node_longitude=True, + add_relative_positions=True, + relative_longitude_local_coordinates=True, + relative_latitude_local_coordinates=True, + ) + + # Construct the mesh. + mesh = icosahedral_mesh.get_last_triangular_mesh_for_sphere( + splits=denoiser_architecture_config.mesh_size + ) + # Permute the mesh to a banded structure so we can run sparse attention + # operations. + self._mesh = _permute_mesh_to_banded(mesh=mesh) + + # Encoder, which moves data from the grid to the mesh with a single message + # passing step. + self._grid2mesh_gnn = ( + deep_typed_graph_net.DeepTypedGraphNet( + activation="swish", + aggregate_normalization=( + denoiser_architecture_config.grid2mesh_aggregate_normalization + ), + edge_latent_size=dict( + grid2mesh=denoiser_architecture_config.latent_size + ), + embed_edges=True, + embed_nodes=True, + f32_aggregation=True, + include_sent_messages_in_node_update=False, + mlp_hidden_size=denoiser_architecture_config.latent_size, + mlp_num_hidden_layers=denoiser_architecture_config.hidden_layers, + name="grid2mesh_gnn", + node_latent_size=dict( + grid_nodes=denoiser_architecture_config.latent_size, + mesh_nodes=denoiser_architecture_config.latent_size + ), + node_output_size=None, + num_message_passing_steps=1, + use_layer_norm=True, + use_norm_conditioning=True, + ) + ) + + # Processor - performs multiple rounds of message passing on the mesh. + self._mesh_gnn = transformer.MeshTransformer( + name="mesh_transformer", + transformer_ctor=sparse_transformer.Transformer, + transformer_kwargs=dataclasses.asdict( + denoiser_architecture_config.sparse_transformer_config + ), + ) + + # Decoder, which moves data from the mesh back into the grid with a single + # message passing step. + self._mesh2grid_gnn = ( + deep_typed_graph_net.DeepTypedGraphNet( + activation="swish", + edge_latent_size=dict( + mesh2grid=denoiser_architecture_config.latent_size + ), + embed_nodes=False, + f32_aggregation=False, + include_sent_messages_in_node_update=False, + mlp_hidden_size=denoiser_architecture_config.latent_size, + mlp_num_hidden_layers=denoiser_architecture_config.hidden_layers, + name="mesh2grid_gnn", + node_latent_size=dict( + grid_nodes=denoiser_architecture_config.latent_size, + mesh_nodes=denoiser_architecture_config.latent_size, + ), + node_output_size={ + "grid_nodes": denoiser_architecture_config.node_output_size + }, + num_message_passing_steps=1, + use_layer_norm=True, + use_norm_conditioning=True, + ) + ) + + self._norm_conditioning_features = ( + denoiser_architecture_config.norm_conditioning_features + ) + # Obtain the query radius in absolute units for the unit-sphere for the + # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`. + self._query_radius = ( + _get_max_edge_distance(self._mesh) + * denoiser_architecture_config.radius_query_fraction_edge_length + ) + + # Other initialization is delayed until the first call (`_maybe_init`) + # when we get some sample data so we know the lat/lon values. + self._initialized = False + + # A "_init_mesh_properties": + # This one could be initialized at init but we delay it for consistency too. + self._num_mesh_nodes = None # num_mesh_nodes + self._mesh_nodes_lat = None # [num_mesh_nodes] + self._mesh_nodes_lon = None # [num_mesh_nodes] + + # A "_init_grid_properties": + self._grid_lat = None # [num_lat_points] + self._grid_lon = None # [num_lon_points] + self._num_grid_nodes = None # num_lat_points * num_lon_points + self._grid_nodes_lat = None # [num_grid_nodes] + self._grid_nodes_lon = None # [num_grid_nodes] + + # A "_init_{grid2mesh,processor,mesh2grid}_graph" + self._grid2mesh_graph_structure = None + self._mesh_graph_structure = None + self._mesh2grid_graph_structure = None + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + ) -> xarray.Dataset: + self._maybe_init(inputs) + + # Convert all input data into flat vectors for each of the grid nodes. + # xarray (batch, time, lat, lon, level, multiple vars, forcings) + # -> [num_grid_nodes, batch, num_channels] + grid_node_features, global_norm_conditioning = ( + self._inputs_to_grid_node_features_and_norm_conditioning( + inputs, forcings + ) + ) + + # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size] + (latent_mesh_nodes, latent_grid_nodes) = self._run_grid2mesh_gnn( + grid_node_features, global_norm_conditioning + ) + + # Run message passing in the multimesh. + # [num_mesh_nodes, batch, latent_size] + updated_latent_mesh_nodes = self._run_mesh_gnn( + latent_mesh_nodes, global_norm_conditioning + ) + + # Transfer data from the mesh to the grid. + # [num_grid_nodes, batch, output_size] + output_grid_nodes = self._run_mesh2grid_gnn( + updated_latent_mesh_nodes, latent_grid_nodes, global_norm_conditioning + ) + + # Convert output flat vectors for the grid nodes to the format of the + # output. [num_grid_nodes, batch, output_size] -> xarray (batch, one time + # step, lat, lon, level, multiple vars) + return self._grid_node_outputs_to_prediction( + output_grid_nodes, targets_template + ) + + def _maybe_init(self, sample_inputs: xarray.Dataset): + """Inits everything that has a dependency on the input coordinates.""" + if not self._initialized: + self._init_mesh_properties() + self._init_grid_properties( + grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon) + self._grid2mesh_graph_structure = self._init_grid2mesh_graph() + self._mesh_graph_structure = self._init_mesh_graph() + self._mesh2grid_graph_structure = self._init_mesh2grid_graph() + + self._initialized = True + + def _init_mesh_properties(self): + """Inits static properties that have to do with mesh nodes.""" + self._num_mesh_nodes = self._mesh.vertices.shape[0] + mesh_phi, mesh_theta = model_utils.cartesian_to_spherical( + self._mesh.vertices[:, 0], + self._mesh.vertices[:, 1], + self._mesh.vertices[:, 2]) + ( + mesh_nodes_lat, + mesh_nodes_lon, + ) = model_utils.spherical_to_lat_lon( + phi=mesh_phi, theta=mesh_theta) + # Convert to f32 to ensure the lat/lon features aren't in f64. + self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32) + self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32) + + def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray): + """Inits static properties that have to do with grid nodes.""" + self._grid_lat = grid_lat.astype(np.float32) + self._grid_lon = grid_lon.astype(np.float32) + # Initialized the counters. + self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0] + + # Initialize lat and lon for the grid. + grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat) + self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32) + self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32) + + def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph: + """Build Grid2Mesh graph.""" + + # Create some edges according to distance between mesh and grid nodes. + assert self._grid_lat is not None and self._grid_lon is not None + (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._mesh, + radius=self._query_radius) + + # Edges sending info from grid to mesh. + senders = grid_indices + receivers = mesh_indices + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._grid_nodes_lat, + senders_node_lon=self._grid_nodes_lon, + receivers_node_lat=self._mesh_nodes_lat, + receivers_node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=None, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([mesh_indices.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=senders_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=receivers_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")): + edge_set + } + grid2mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return grid2mesh_graph + + def _init_mesh_graph(self) -> typed_graph.TypedGraph: + """Build Mesh graph.""" + # Work simply on the mesh edges. + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(self._mesh.faces) + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + node_features, edge_features = model_utils.get_graph_spatial_features( + node_lat=self._mesh_nodes_lat, + node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + **self._spatial_features_kwargs, + ) + + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + assert n_mesh_node == len(node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set + } + mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + + return mesh_graph + + def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph: + """Build Mesh2Grid graph.""" + + # Create some edges according to how the grid nodes are contained by + # mesh triangles. + (grid_indices, + mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._mesh) + + # Edges sending info from mesh to grid. + senders = mesh_indices + receivers = grid_indices + + # Precompute structural node and edge features according to config options. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._mesh_nodes_lat, + senders_node_lon=self._mesh_nodes_lon, + receivers_node_lat=self._grid_nodes_lat, + receivers_node_lon=self._grid_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=None, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=receivers_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=senders_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")): + edge_set + } + mesh2grid_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return mesh2grid_graph + + def _run_grid2mesh_gnn(self, grid_node_features: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> tuple[chex.Array, chex.Array]: + """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes.""" + + # Concatenate node structural features with input features. + batch_size = grid_node_features.shape[1] + + grid2mesh_graph = self._grid2mesh_graph_structure + assert grid2mesh_graph is not None + grid_nodes = grid2mesh_graph.nodes["grid_nodes"] + mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"] + new_grid_nodes = grid_nodes._replace( + features=jnp.concatenate([ + grid_node_features, + _add_batch_second_axis( + grid_nodes.features.astype(grid_node_features.dtype), + batch_size) + ], + axis=-1)) + + # To make sure capacity of the embedded is identical for the grid nodes and + # the mesh nodes, we also append some dummy zero input features for the + # mesh nodes. + dummy_mesh_node_features = jnp.zeros( + (self._num_mesh_nodes,) + grid_node_features.shape[1:], + dtype=grid_node_features.dtype) + new_mesh_nodes = mesh_nodes._replace( + features=jnp.concatenate([ + dummy_mesh_node_features, + _add_batch_second_axis( + mesh_nodes.features.astype(dummy_mesh_node_features.dtype), + batch_size) + ], + axis=-1)) + + # Broadcast edge structural features to the required batch size. + grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh") + edges = grid2mesh_graph.edges[grid2mesh_edges_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(dummy_mesh_node_features.dtype), batch_size)) + + input_graph = self._grid2mesh_graph_structure._replace( + edges={grid2mesh_edges_key: new_edges}, + nodes={ + "grid_nodes": new_grid_nodes, + "mesh_nodes": new_mesh_nodes + }) + + # Run the GNN. + grid2mesh_out = self._grid2mesh_gnn(input_graph, global_norm_conditioning) + latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features + latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features + return latent_mesh_nodes, latent_grid_nodes + + def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None + ) -> chex.Array: + """Runs the mesh_gnn, extracting updated latent mesh nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = latent_mesh_nodes.shape[1] + + mesh_graph = self._mesh_graph_structure + assert mesh_graph is not None + mesh_edges_key = mesh_graph.edge_key_by_name("mesh") + edges = mesh_graph.edges[mesh_edges_key] + + # We are assuming here that the mesh gnn uses a single set of edge keys + # named "mesh" for the edges and that it uses a single set of nodes named + # "mesh_nodes" + msg = ("The setup currently requires to only have one kind of edge in the" + " mesh GNN.") + assert len(mesh_graph.edges) == 1, msg + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_mesh_nodes.dtype), batch_size)) + + nodes = mesh_graph.nodes["mesh_nodes"] + nodes = nodes._replace(features=latent_mesh_nodes) + + input_graph = mesh_graph._replace( + edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes}) + + # Run the GNN. + return self._mesh_gnn(input_graph, + global_norm_conditioning=global_norm_conditioning + ).nodes["mesh_nodes"].features + + def _run_mesh2grid_gnn(self, + updated_latent_mesh_nodes: chex.Array, + latent_grid_nodes: chex.Array, + global_norm_conditioning: Optional[chex.Array] = None, + ) -> chex.Array: + """Runs the mesh2grid_gnn, extracting the output grid nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = updated_latent_mesh_nodes.shape[1] + + mesh2grid_graph = self._mesh2grid_graph_structure + assert mesh2grid_graph is not None + mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"] + grid_nodes = mesh2grid_graph.nodes["grid_nodes"] + new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes) + new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes) + mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid") + edges = mesh2grid_graph.edges[mesh2grid_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_grid_nodes.dtype), batch_size)) + + input_graph = mesh2grid_graph._replace( + edges={mesh2grid_key: new_edges}, + nodes={ + "mesh_nodes": new_mesh_nodes, + "grid_nodes": new_grid_nodes + }) + + # Run the GNN. + output_graph = self._mesh2grid_gnn(input_graph, global_norm_conditioning) + output_grid_nodes = output_graph.nodes["grid_nodes"].features + + return output_grid_nodes + + def _inputs_to_grid_node_features_and_norm_conditioning( + self, + inputs: xarray.Dataset, + forcings: xarray.Dataset, + ) -> Tuple[chex.Array, Optional[chex.Array]]: + """xarray ->[n_grid_nodes, batch, n_channels], [batch, n_cond channels].""" + + if self._norm_conditioning_features: + norm_conditioning_inputs = inputs[list(self._norm_conditioning_features)] + inputs = inputs.drop_vars(list(self._norm_conditioning_features)) + + if "lat" in norm_conditioning_inputs or "lon" in norm_conditioning_inputs: + raise ValueError("Features with lat or lon dims are not currently " + "supported for norm conditioning.") + global_norm_conditioning = xarray_jax.unwrap_data( + model_utils.dataset_to_stacked(norm_conditioning_inputs, + preserved_dims=("batch",), + ).transpose("batch", ...)) + + else: + global_norm_conditioning = None + + # xarray `Dataset` (batch, time, lat, lon, level, multiple vars) + # to xarray `DataArray` (batch, lat, lon, channels) + stacked_inputs = model_utils.dataset_to_stacked(inputs) + stacked_forcings = model_utils.dataset_to_stacked(forcings) + stacked_inputs = xarray.concat( + [stacked_inputs, stacked_forcings], dim="channels") + + # xarray `DataArray` (batch, lat, lon, channels) + # to single numpy array with shape [lat_lon_node, batch, channels] + grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes( + stacked_inputs) + # ["node", "batch", "features"] + grid_node_features = xarray_jax.unwrap( + grid_xarray_lat_lon_leading.data + ).reshape((-1,) + grid_xarray_lat_lon_leading.data.shape[2:]) + return grid_node_features, global_norm_conditioning + + def _grid_node_outputs_to_prediction( + self, + grid_node_outputs: chex.Array, + targets_template: xarray.Dataset, + ) -> xarray.Dataset: + """[num_grid_nodes, batch, num_outputs] -> xarray.""" + + # numpy array with shape [lat_lon_node, batch, channels] + assert self._grid_lat is not None and self._grid_lon is not None + grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0]) + grid_outputs_lat_lon_leading = grid_node_outputs.reshape( + grid_shape + grid_node_outputs.shape[1:]) + dims = ("lat", "lon", "batch", "channels") + grid_xarray_lat_lon_leading = xarray_jax.DataArray( + data=grid_outputs_lat_lon_leading, + dims=dims) + grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading) + + # xarray `DataArray` (batch, lat, lon, channels) + # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars) + return model_utils.stacked_to_dataset( + grid_xarray.variable, targets_template) + + +def _add_batch_second_axis(data, batch_size): + # data [leading_dim, trailing_dim] + assert data.ndim == 2 + ones = jnp.ones([batch_size, 1], dtype=data.dtype) + return data[:, None] * ones # [leading_dim, batch, trailing_dim] + + +def _get_max_edge_distance(mesh): + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) + edge_distances = np.linalg.norm( + mesh.vertices[senders] - mesh.vertices[receivers], axis=-1) + return edge_distances.max() + + +def _permute_mesh_to_banded(mesh): + """Permutes the mesh nodes such that adjacency matrix has banded structure.""" + # Build adjacency matrix. + # N.B.To make sure ordering is preserved, any changes to faces_to_edges here + # should be reflected in the other 2 calls to faces_to_edges in this file. + senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) + num_mesh_nodes = mesh.vertices.shape[0] + adj_mat = sparse.csr_matrix((num_mesh_nodes, num_mesh_nodes)) + adj_mat[senders, receivers] = 1 + # Permutation to banded (this algorithm is deterministic, a given sparse + # adjacency matrix will yield the same permutation every time this is run). + mesh_permutation = sparse.csgraph.reverse_cuthill_mckee( + adj_mat, symmetric_mode=True + ) + vertex_permutation_map = {j: i for i, j in enumerate(mesh_permutation)} + permute_func = np.vectorize(lambda x: vertex_permutation_map[x]) + return icosahedral_mesh.TriangularMesh( + vertices=mesh.vertices[mesh_permutation], faces=permute_func(mesh.faces) + ) diff --git a/graphcast/denoisers_base.py b/graphcast/denoisers_base.py new file mode 100644 index 00000000..b30b39bd --- /dev/null +++ b/graphcast/denoisers_base.py @@ -0,0 +1,53 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for Denoisers used in diffusion Predictors. + +Denoisers are a bit like deterministic Predictors, except: +* Their __call__ method also conditions on noisy_targets and the noise_levels + of those noisy targets +* They don't have an overrideable loss function (the loss is assumed to be some + form of MSE and is implemented outside the Denoiser itself) +""" + +from typing import Optional, Protocol + +import xarray + + +class Denoiser(Protocol): + """A denoising model that conditions on inputs as well as noise level.""" + + def __call__( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """Computes denoised targets from noisy targets. + + Args: + inputs: Inputs to condition on, as for Predictor.__call__. + noisy_targets: Targets which have had i.i.d. zero-mean Gaussian noise + added to them (where the noise level used may vary along the 'batch' + dimension). + noise_levels: A DataArray with dimensions ('batch',) specifying the noise + levels that were used for each example in the batch. + forcings: Optional additional per-target-timestep forcings to condition + on, as for Predictor.__call__. + **kwargs: Any additional custom kwargs. + + Returns: + Denoised predictions with the same shape as noisy_targets. + """ diff --git a/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png b/graphcast/docs/GenCast_0p25deg_accelerator_scorecard.png new file mode 100644 index 0000000000000000000000000000000000000000..d9173acb409a56186988f90fa74def4d1caf19a1 GIT binary patch literal 171471 zcmc$mby$>L*XTizE~N(Pkd6VQOKBvO5|kWDX@<^0q;UW#>5>wVmhSHEmhSGM&du{Y z@B5wj&U;yt`6&|0LlEE>3Hce)1u8F5DQ7vRT#LnHg$y1 zxLRA;I0(Cn(ft)d7`T6U%|%D^R}e=_F*hH`%ACeM&V{dK>6!)P};@qNtCI7#?`>(=@0;vm2*_%QfZS6H| zZLP%rUmIMM>wiz-pTtGE9;)Mis-C~={;#({<%?qhG5^(&;#kopM1@F5FOd|aUun1^ z@7811$*5o5HwMWP$4HvR2sH!*5ykWH2W2Wl{UEO*WK~-7B?bmE zEe)#QX99vp|Hp%m_>;KuK|3tQ|I7FP#5}|Z>i!w2$Nu`bX!U4Q7$m9+b-GUe_7b|K<&wH8}TeRmkRe?DvL@<#N?uXBCz{;twvx>tXl zMB5O#@^nP8wz7WD+Qt``->%4O&r)YlT?eyRHP{@8>T=*dS!=YBUQa5s)NoOMP<`OJ zA<}SryM*s;=Y2i9ILD@cge7o)b5dJAsXuGoaJ6pJPOL@oA>me5TBm9$fP8KccFudz zYj8IrcG!f{Aa-+be^+$C#N~a}JHPyuYyaJQQ{u|;cRJMq4M(Bnkwz=i_z?>RcgL}? zFRXoWbVKjR*;|ix4DM%uf*CkSTsGs+*?M2=EOr>2{vfT^YDA3Nti@f?R)g6)=c&h-l{?q1_I-f%az{obpXJ#TNS*~(6wdrYeP2X?bUeSI^+v>s=S zmui4mW=ynNUbx+NsofAeaDztDd)*Iu->tyT@(rBqs(>~Wt?)*Srp~3_-+UAR^0})1 z4n4rNb@kP;EGYMEFgQt=e9^;nr02Ti57VvRNb#7XzQ2^~6jT&dWH0fXwQI$lzh1nD zUG!)*l&aQ^wx4t}rOxr$aMm5R*wpS-By!4V2$_K&{aV43xH;rR#S=Sy8~b4;Mr;Iq zKxiYWG|C&OkjA7v>Rsynx)T+N!oZ{3{n`WPUlGQL)cY%HIRnerttnBGh-YjlPZ7kA zcyW+*h=PI9{ppDh(3^_dbHvYbTR=hnxLl!k`iPNO5OgP@x#IbrT)6T0}@ zDcKr+&awevO9;GUIX#dnp$?A!%{~uzs>Ay;DBpJh=w9!AQsMK7UZ~w1W(^6uB=B} zf&Qh#R)mZX+oLI(^1i>tp6#0KnfWMw5^n_5(QZ5*Ohsrdx<}CAiZ4LQv?})wn(j}s z-O`DNsWZr56|=)H>U)xw$?f(QHRsKxu1-@uDK}c6^ZUDk?GrB7X6#5v&r++yD4!&` zgIbeeoUv(;dLuA6;pmTQIIS~E?0t6?ran{Sm)T<)tGB-2cz;*(%IsHTeo@1WqW80m5)w4|b(&X} zO0^&E0=Kq+K9wVpYx|68BpwJ+6I*EU?A@LJOERB|S(*{=2mK*(dpa8RlzFZ05N^^f zvT%gQ^$O(8)+lk($GbL}E_M$n1eg0WgZs5@eLfiryJV*c=;GaZW#IQfRcTaMC#6eL ztmv=TgWu0=CiPsLsIR)@YBEpNmVO?HUCfQ{m2^_+_i(St**xeXAAJnb>)o>Yy3@1z z0wD!sQ-ksJmy(7bXlWQ%e&?wWo3YX!&L6-gd9AAUSrUSf5;^$|tT~bxv-x=!hvqK^ zTpC89{Z*LFQ-8M49R|fF8KyE3Yc;urOUf3+aLr}!`|HKPltp3dLtpBfw3R=(n8Wy$ z3+|R$7inn(Y@QC^_1H(*sW%_Ukm{}C7u&mCo#sgFk+oGH2wea= zrEGw~-9B|qTpIN)dQ-aTO<*2z1TdKqKKW#+GNLW?TdV%AK@f8KNZ)NOW?;<1)qD;GPq2x+koj8bfJDhLv&z8_COvZs zOerhZ%6eO^x}I{BPSL4R%Lkj*&uvc07hD#-^j|M(Xe%8Cn9!_Tntan^I~DW6O`eF_pP`z~6&#h{I^1~*1>2P?4+r&%U*=v1z0 zQN!7smjU~q_w^>892qnUy9K7al3_+gNF#aeZMvCS2i)18gM8oWI7D{zyt0si_&2?f zG!(UQVvChD^du<8g_P`;_A;L3q7)3UHSd&;_UK15SXuN+DJ)K&Gn+j@M?OVW5fLiXKfdXULk zY%nWf%}=YDPd8dw895Z*H=e-2Vi^4~#~q)k2I;`6xfv8Ws8gUc)vT%bbPY@ag_H5- z9d4rHZDqr<sGrrmaUfcez`8Z8t^=)P zTpk8W=(5*r{%D}A-_|T*?dwv9BnGx*?s^s{4de(TzQf&io%*<))3Co+U{#zx=4Dg^ zsb?1ZhU%Wz%J_`%wzyGw?$&#{iLzE{>|8SXwb+}wLP_VO_`eF-LE znL*K$zLv}drZ(AG*4XrYhxg%nh4W|<&LEw`zBxRq4$)IrKltf~MpW;k1HsVl%^?); zj^pRCx+6uF+KWRHWU<4|Hz^l&okhNV#nF=}LA2SgKf9{`#fPc#d@HUuEMy-dlgMFiRXDs3N%JS7_2-9JeSDH(dP1 z>nqvbsW}f-x?Ft8N#p%>BQ{|+v-a1~I~hqsM!L@_pa2ntC1AClrC{K;Ng|SH5X9!X=ka9s~P!GM zwA5>h6j3Dw{1_-1NibLb+#ZG(pf)U6Por7A(3ZrSya-Dm?x52)J?Zz>FGTVClO`Xm zerK~zhx(YLK<3&^v?j!)*bG51A;XaPTEF&C+h&7H|6yfa(>x;{qtl!oLkig;@%yX6 z*o>X)X$G;|8JgAdeOtG%E^aaR_I7Pj=E?Izl#Gy2PYP-WwY0o5#Y4jTu+cc2QM>)> zjr|syj?5c8Uh~q;0CJl{_>eu!eh^FrsM3oblww_SUOwEElGbMeq!N3X13q4qzwA#O zsyQ~6yJLWv+Ry$-qu`0;jdj0*byy7TK|M@h+9axw-%tRJ)we`HXckwJ)>JAedM|_e z_EU2e`E`MN=VhooMxQh);{voqpYG`+NkcHIQY0H$^X%fCbEoDwjM-&8knQv4ATOgT z$NYQ#E;sr|jC?V5SXwca^x>i-Y{eTh@58DX1I0OYRvq15X^{#SskM#s8xv%Uo6c9^ z@!<#`)Sb+AklQAW;bZC_cpZ39qG=>Va(gGDktKNROu4knkhi>i9Rfa&{GlPr;;qY) z?e#Drs(9^JO=nLR028yqFY8wuR>liQI%nT%_Y-O##fR)XEIiqY>PmS7qE>2h6nq4u z9zhwAmz8Fw3=EY0$=mw*X$JUz$H+>+P7)Vg;Z;-%eof!=e@z;AdleY|qz7lmh+M5k zmEDDI?evuvx2DMv$7_U^q4=?K6?m^83b}QY6rKnc#`y3}x~ZGsOhlJi=Hsp-R0|j5 z+qBg+32FI;I6l9Io#+&zXY{Oi^+~iR>j*Qk(eUJB; z7>yA%(KDzWqnCMt5U65qdNE4=ObsiyjhWK?YhNg!M~g?a2S4hHF9dIV!RD0r%5x+i z7yYPKG5~7ocK_LD7#_azf{~YL3OBSP07*_$j(g1DvV}BB!DHO!1hgO-6E;r*eItHl zLZP=N;`O-v@#z&sxkb|&YUjKX3`5XV^iH&7EMZLB%RzzLB8`6E^q*h2|*tw z*isamaGsNWCxW^C3Ii_m$cQJ*?J?`MIyN#6)|~4~Xt&aIOnv4xO;}s7(|k6nZ`SLN zCamIdASVnddhKmE=DJ+StL}il!rAr(Im7^3N1?BCR`5W}E2l&!nVC|QzBwZGu0m#^ zTH$CVvCiLR`3?&ACG=Yma5i4LKt$*~&FNt)K9jEiL?Z*6STA+6#n@P~)|AkQ6)QyFP0e~82on;9${Vf&M-olgXLHkt(2!+GMdp4f7b*T zGs>boSanjFwVti8szt|0Aa`MGMq#=iKP9rfjbExm$1jsEaeAquNiSUxD7 zFN$nRN5?@;;s{UpU58;4#fuo(5{Gkw_DtdWfdj#p@$b8967HCQp_?WSwB1DKYjOXV zZ%~!A0*zM^=@`DKwSA$l1hLe1Lby47QB{?kpZAh0yrK&UB?{SxyliVzFdGeh-0^8Z z{MxE$mCBs7R+vC;XHos>SN46i-@vw9D64GXFfeb~@{nThcjwAw}t z<)Gk*jMA==g?Y(x(}HB_)e^Ixf7fo|z50|2hJITGi+v(&Id^}>%GUm(#8T>R2xXl} zi5vbp=FNI=tZ1{R@h_IBtqS{Wl;Yhnx9r8)7?aL)Ecj*gsaB^0Ha_?%gt@cB{%rbEHZD*-=82X?SU=h9YqBwxjiMVk!;O=f zYtYRQiFp)DhnX>`mDf=~doD;@b8%4{M@_Y%&vxynjKu8Z|$1;`2ZbbCd zN6NlSWe?MWmJc+1i>jT<_=QTOH3zM2$~rP3JTlcstZS9I$8N0=~tfyV!?`weVlctV*^g$Yj#!-p*JGY9*F!4C%$*%57;X5=jJ%ML~Q z2AG)9`D?2!V6f3!dW&VN)IOE(U5Q!N+sG|gbF=RFY})U`)N*m}ea%kcEHRgxKz0Xe zqfos!=o1-V`d7}5SeLRpAzC&wNO>ZllW5TWabE_4M^hU!7pwpEMqk1H75|B@@yKVC_@Px_ha;lnFUiMWP5X% zfoevfDf33BQrsS)soee&0;ep}A)b$AjM2?CCX%9yk`eL>G!-3+3MwhoFu=T805-Hy z+V8fc+Aws4k`Li3MLe2o!*T4UMR#PkIkbhxMe}K!$2ybE{L^~ShLhxjU9ve4mjC;G z6@+3x;pW#pnQCXn_j`GkYUpM{19OPSNxd(^jCtJGdYlG$YtHhz;u_k2UemY>uXrBY zz;9OLZbEH6>31YJ$yc;kZ5(4~ATy{3eTu&zM)x3pvS`r7vMah*&K`d{1<7f;{@6s) z*t+&~zS7!3wbK(cir%)W@T2cl`{L&%;W72ZzY>ikdI7r)szffSqi{l#+;-HQU(;{7 zmyvLOiy*(IVuC-8Vlvg1c`S`ibAl|g!c2^Xm^2X)KgPPLavUnARNM~i>Gz{ka+6XT zwZA+MRDNxr$##T>Cr?JJ`ALbZXcpa^DRWNHN%~8$#Hn^RbHnjh6B8xfa&8pXm`Hlv zS(y+E9t@TiB4}Pp#+(4d`1VLMI`!Nd@vq^^csho+2~VNvT-=NdCRW4LA};lb$b_Am z3MvT^yDJFF5$a8IsUq|J1ZL9=y)#82Vys!~%H)+cQAT3|_xdj-9^1*KnOz`zy-3+h zDQe4YcQuMPuPfrCHYN9kM-H2_Dy?@*px=LMGFb>qVKv&%S^^rj z5|@<6#E@IX_ie{l-YLF(3r4do0w>E+2r%7M0yZ!TtFnqpy2i>1DxG((0$(@768TMy z^gqC^)_PU$g!B`g-Qtk(tV7MM+h>2j9Aczab!0EJaQIb|MC3%F23Z{$r@v4v~YuR`-*=VL&*V3K%o8=@nSDhVe zJf_V~QRAK8#o<$fMx1?aUkjqCrkj;#rWV5l$C_v4wm+*vP<>15?C5Nw;Zh*cmoCq= z^!vYL@G9U|u0|@pdp?X1I8OYL*ALsyV}Y(NhpcQ^PWNR}75`?R9lSvIMg3=J|(Z=xaiMs3s=fAR6)uqmC6_nG9ANwI-$h+&CRK@Q5yIEHv z#{Bl52+j(%n8^IDF$C+$D{6y zzLA(WSjkN7bQm`oTTDB#X4XkI==c=p|C0!#ZCnCD;s@j9?>Y7EP>AaT+a?t9NzUEg zM4H*`ZZS1L$)i~4Eo_bG9XIz^@z>|+zOn-g;^CrK;FK132D7>|U(Hc)wNS1_7)=of zmxa)av@*}G**og)X_2xN1@Bg60e3-x>vb-J&phagAv*U4OyW_ zjSf?2O-_xFEQ6Xf3sz$7AqOyOd070+kH6}()-I0=_(uF5D$&kPTjZ&Xm&_qZRQttT zI`ro>S{bBN)cwQNm*QB}EE=uyFp;OLRUk8F`|UstI|Kr5IF2dW4GF;Yru-6**Gk`R zD9wuRf8Ycn6GAf%WP8Z#@c`9QCEZ0h9*5s6$8$1+z9M@#mkWq^eZ!=!PFh zg_NM`*+#Q5#?i|QcfKGy6WH`5k4@iB=6>c?Q0*=&3=;F+kx+q=Oi7^|s zWl^ZihDP6p&&DY&2z~0MbZlncVC-;dyw*DdS4E{a%@Z~8Y(7r{$WLkRVv8T3Gg z9TId13AxaxHvgKCVD7Bq$mv`U zQeV^XQl)kxgRcpCh;6{*R>9lE)QjdW>dD)16sPT9q4q5|@}UwxJp^})TZJRrI)N3; zkib~$Cto3!#oP&|m70tOZdO!!Qkir4NqgMv?~LqQl)K1Zek`2y)~SD6K8)VH*e5m9=_tBWBftq=-ijP zlH9POlPHt)iBix7mAuh9UP8ozEFW@Z);3&@0@h-ul$#fFlIcPVrhGyW1r?si>v>_0 zQGC?H%_8`ioCMcl<$Fzt;4jbbD?V-sA21@tBR~5R5Z-%T%uf@`YMSgvbTyy4I>jfn zKEv)wtV=Taq?wX8^HOiu;qF_P4=p+hhWV(5UYZsm%L|4N-|X5g%!wPSfA{RSTKEbj ze*$l)N$X{H@FMzteC15Q8p$K?lXQGx|F%O=GF4WX+7meMdE-|8ghkA%y1ID0Zvvjg z-*z?;;+J6-`fQScfDdocx~C{?&O`m$HSDBVF{M$fXAh)54^GYrP8j6m$Pw4 z@JTp(JUJ?%LveqWf=yj5SpM4HRNS@FfIYmXwNi$gn-KlF8`oQ#Pty{4VlSsT@~cgKQ+sZdQrB!F7TOs zbNjB@kNk1-jbzT-khl6NGa5NJXqBa7Ki=D74-mp2{FxY-Rj{DPH^Mu=R*K*y$hE+Z z>h^3GSj1Jb3G*+Rin-`feQ7 z+9QbL-e<3s%S|-QSaP_LK;{)&(ZQFgG2M-#HCsK=9q?z**1L8QBNH^Qc|~bP-glOv zdrOa-9Le6_uC5uT>bVjU_nvJ3HS2Skku~#33J|_SA2%q}bUDvv`zGX~3O zw?bo*kNkGp<*z!Ud5A#QWD_E?U)7>{7)frj8ioS`yp+W=1xCFc9JxK*-a=vNOa=4Ix%5U=e()$*(f!+p<8v8&y4f^x77jbj7`RAUTh#EB7o~lo8Ti)EK#tTvWy^J2*yJ^$aNDm@Mci1T z+JA&4C+Vq&(`NP;1p>kmMRJ=e$ZnFC#BmpWZ9_>!pik0|Il7B6^*3je{l;Pux=T{t z0}CDp$svU!q00DL<_?ibv|w~(ZTvT?oQ*7vB{~C5*n8mT&955!xCd6M1f~pqr@F}t z)5YCGkv#)xzi52wvr}N14SW(mVj9T$C=QlEcPe3oOJO!g=`y#38=dLsP8;LCS3X$- z*vq?jm@>q8W!HaU1m57+;u;aLnyy7lg+!~&Xh=M_71=G8>ROhE?Z6e`;Xm_&f_;mD zGiy=j^Ml_WAdI8MH2AiTZT&c<^ArT8q{ygikGi6<*s3`9eUmAd%@%%eIv)4!o5EJ4 z)XcB$C3y;4odZy|uiK!jo{W4?R;DMTWlDaJWs&WJdRNLx~O5N|pA(g4u;JWo`%=LHacn<;|F-Io0g|nMB1h z_9W^sVS@{4Pj!EQOfE%2me~_BrAJcrxL9^AbRb&OSW&2KaOAgB0XfCEZV;nu|E(!j zrAJ+X&#Ad1iZmLm%cwlg4wP@ndFjX`%o7IETcQ(5ozDu3anlBC6-LU=;1F?O_t#hWX+3 zrhfK6&u_^b!0|~6Zy0W)CJJl!KTHu1wSfHmxm~f<>yl3u7m-?d;-%rXiMgVq+19?} z&5VXb_v>KZT!U2n+0Vq(z?nN%l;{}Zeh^GHg00PjWiI^-s-U>#~TY*z$W^O0)T(Bf^|jd zalbK3@8_3KR4liiox1c?7hrs%JGIWVHOTNMME9EC%tC81Z9sj?=gpk0>zm=)g|otz zQ%({&3aB&$2U8*yyiayg0uDA<=*C{bc6&Ik#um{y zzIwiR$~D@m8Mn}El7eZ8?x*_e{ytyfm>50^eQ$1QhSZ+Z?AhdycguY)3LqnT&_L17 zqYejEjHxki46U=?8Q)WD2i-W6Fly>hhz|Io6b8OPxS4<7bur}Us0d+!RA!_5CWAKKMqsy^t;XH<%4%Yz%(^ph`_ zqtGo@g3!q}-qQ~UyhrEuE?X~`Iy4c^7H5&pFrWn+6Z~0KV?gzLiH*@Tl+Rld!&IHO zp+N%a()8atv*w9|N4GK8Tn8IJ&sdY7!QkTm&8hRjUbJKvozMvO>MTkfOk7q zGK_|;@R}D0e#&r((%?|CD1G+*XR>9%QtHr% z)3>pL_dkQM&dAz;6#}`bKC3HbSC+Q2?`o#ZUD$hL9oYgLjaPZ9{j&@so&eF4-XE#` z{C)Qt203DxKY>2vUfmkRqo4b+2Z>ie4z81Hr#ze=elg` zhwrv|5$ha9AJw)k0;^Svh4aU5xN>6oBqH#a@?a=YB?f)F&abpiulH+ht$!H?q4u0L z3-^&Pe;TRiuvQA_p0cH^2hKIH30z>apA>2B@wCz74+N`Tb9dO|)O@qKnH}?{5Wi6I zP#Tv>c24M3FC2^td^M?8?OD#pRL=cwefvjPZLZMk39X+}>)N%|8f2Ol98X#<5XzVo z8jRWDl7`P{(soApd4l|2l4GQ?p^H%*jpswM{T~@cO@Yj!>4U%S1cA)JvU6`o%DN>y zI0RIX=D#sMCJ@2Dy=_9lcLYpdAGPY75p!!30Vrs6Rfxoz6ulU#noG{=)u-O$6+H%5 z=4pdHZpvpi-Z*XaRqltjhlT$YTwW zNIyBa=6E@HAZ9gJ^3TVo`3q0kSn;ct>szqcl-$q3d8(6fYF+HciPqwl7exd%rLdB^ z?z$IYZN#24%3+sbc!k-&dpT=Qv)0qig`<|dulFi1CgxXu4=Y=vzbHA0cBk&-uBJvJ z!&jvTzvfwwnN04Zn28=L@14m?Emj^)*t(njYkG)RrzP4>J)fD-@6U9Z2Ty#XEdt=9 zl`k((hP`!PXM1A#uzW_6MIRyyxCOOxs?9OF0oR04&x=mTUMqX}g_n{z>ty)lX`cxaoaO*=`z)UTzjZJBjK zTis9XTpv3{qY5_H)SXFe!K*!#9zqwTxE#r{^nda<=fqC={%0b#M2ZAf)a2fE9arOm}J z)|yiFHUMa_dK)be+^Oe#&tjSqOGgY0MEI4fp-O9RQL%^_YOGAOqR!H435`E=7*Svn zQa@7_;%Z-#y|vL=hCfBcFv5$j3kE9nqz19(D7OphE_CG=U;MuKIbe)|?;KscDY zbIjh+t5*nT24AWUJgFeUSBWaIW-vn=NVMr@U6aDmnX!Su_|3@8w5ze9 zaT<;^p}sftc}zi0g`?T@ZK-l{wata(T3*|52}dzpV9~Qoiy$l=n;%79aM}sGaY)ze zY`oroHBLc*EUvZga$Bo$J51v45|6)f5mwv!lpRfBw+7z%Sf)pNt?d#`)}V<+1|TPk zKfa`Uj_HeG+H0sdy|d3<&p(oP?+uNleJcW1 zIi3dJK840}Dlek3kZP-Pl9**r_UoWAhIZ48_OFv2{l>gA?0{%LrYIKeUq=f-8cCH; zprxz>xU`B$P5y;8+Zxhav|WG$JF^b@4Hc;WPE{6N-mg!_Rf7c9>G$V~VV<-|01sic_U=4V(J#4q>Vp|u0i)$1&Xu$5R;ywG#F zI=sK{{i(}w&30bCaOO);th#Ky+L2MNAme&QO`rsH>JZO3MZEu2#{-UOyA|aeoaJ=f zr`zR4)s>$*O#E`Z}$)@!SndiY{DDLQonta|U{R&ninKt(B<2V(;0mrAfeiS~t1$B#K zuSE4N_I;(m;~_0s=zsqqd}X4RATA7}uO;SSOB)=c@Hh4x5ROC+ilZ~-nNKMvt6Z~l zPZ7tUX-4yX!N?w~rS~fIcE<}r--L!^CKX?7(+3HZ2e`Ol8l+Kw-!L91(g(S`dA-{!+h4}9(jMl2nCUcp)T)+_a(NswTI+Lw zB3&@p+Xx{*nDvKbP6BAK(oeiUuM6ZcP*LA%9nLun`fDwU4?&q9R^yHsY3E;dc6_Q- z5=M*2ZZ@_4p}t@0p^k(IZwGvr02t-J8F0f7^i5}%*)=%B<$Uoq_&HAgaOl#|fi~mm z#*fa|z(S;HlOX7t7Oi{e`NMhVU;H>22bG~)5Qm63@ZS{qpC?VOO#oykW^1SL4a0}$ zk?A^Ghw3+HVoX4?52!u21S>?qWPb38H*M_S8Ph0qnCsQECjFC&&)t(I9+vNfI|L-| zCEvUrsLHGL#3B`ft9S^8VYT){VOHtvj@n7z*q}w%{2fcl61B}FX<|IT1<2_yp*E>~ z0;$})6rV2u2kH9g^jbOh07L6)4CVsR=NrvQ0L`2zDD7d@XSIF+MJ2fXiT;u6-*xiq zNn95A+0IkulmBYme~bFBP%kl*SRaja3+5iOLI2MAzd`B&>F*v;_&cYcPyZ8nuOI{L zoE7{Hg0cUR;=f7D=SU6+bW|i2^Pebwpio9UDy%oJb^inE50`jA4!kMzQ2&E~hH4L1 zG_@tIWgw7}CMp0&fdzWm%N%dTF8}zd^Y!+`0+oVGAP8`ywFm*PJD|T?5`3*gL!D_; zUxB`S|G?2c^lV~+ILwm~@4FM~%HT;SfX4&Co0LQ>%{0;9eoX*;$Cp0vo;=lY zL?Kiyu^{4r`{Jg^`?eQ(1~3NDUqq$ew29puw(bM;wBsAGvuKaIlhgzKjozEsW>$`n^VNKqwp)c+n5L$oN+3A97+uGw10WDRacQ}0#*kxArIF})k_Zx4FA5gQLIzZZ9LVbwC$1;DHJ zUi2;kV0(i^!>%cp1CUpEYRk!w_*n@HEZOky~cd- z9rnrw#dkKQwE%1X{#lWCy$-;l&i{qUpJog2^;nQQO);aW&z+d`LDeaqn%X|2X%Y@S z>jIkthjUKgPFg&IY0=4a()vxA{WUJ{Y8}U6nOii^vVMLCTdy6R+iZ4XancJ(Y>w4Y zWq(W?t9>;MBAFhy)xU6fk-gueky}&%Qi3N7f6%^1(tm!H4IE#$*zaG=yBGinepIjI zv-S^Vy6eJ&t@?>A;X-F)>SE#Rch?6O*N0jFI0@^&NyqoHBd1Ey3^!DPp`&Vbe?avbx~?*d2mk03e3X9 z-Cv=6N8ebn=-`*~BkszJV}rBgk{LX0Nk>AM15`a*KU%R&t4h_g@eZ+rZT!Uynz0z- z*W@T;STm@|Xr<8yoXQa*1g@p$m=Y23_zA!G62IHVS*ni&3?7GQ)fqiHKhiC3C($8d zEY^r0e>f5)48#m0#s^#*vhI&NZ9P}WlZ$Eh!aY~%O}SzQNAGG*l?CQk;VkQT;+KY8 zst_8!AC0$rw)<=bS0NrZN1f#UiQ6gGglk#DVM6Pmiw%kUYE@_Jl^rX{Fms>Mndj-S zTx~(cOy0ZM>rRQgr7*!|jE}wnvGx>vW=X#VtT*IqA4q(gKuTwjlXJ|%3#qid^T~Gx z(!xc!lJ3wL&390sIphj9BZlM@wv3n=k4K^ew+0zUbjgN_I0i^@@^6n_+baqXoXtSI(L)O4mO=p@e@vFu-e}bcN*p zY;r5nwrfuI1ed^yL43^5qr4Hi^Huk&p}lQSlgX1|_!~1=Q@QK8IQtHfx0TqJtNrhpou*!PlaN z>0){iHUV`KmC?1tU#vpZ0c(PsI@5WxWT$tP07C!fa-d`! ze;&BZH@NND8Jz{3SL*kBGz6=arIC`uH*_gz$ zMP>w@yGMlZoN>5V-Rd>a`0DQ)^r1J8YukSB*X+1^1x9Z@z>^~TRmQvRgCCuf z(|f%%@W!}CIQlZ+kW&ZXTvjVC2kZx+O#UX$4M67?QLi65vnFzyiup>@Jo z&gLleWo)))e#~2>aHH*Ueyj97 z=m+<$EU#H4mqTvM^cEb-_F-k$O2YgvcCotl3Rz(?H(5OavjCCHLpyL7+2YOPL_PzN zJOVuAs=!%${@g)W>X|xm5jo3kgm2j4)%#^)L4`4p(8tIUyd)aCzeiR4$+C{G_*%*Q z$$FVTz2Wc1OC!4Z2spUOcKrBcfIq?1R&TIxF!wTl#+M9vcSdp5jyar5 zFZg0Le?nt>E4NZt4v z3G9j)7jfrd+F181TK`^pOWk^-2l(aHbfED{^4L9vhr)%NCR%+NCmai=U(^E~zHbns zYt~3gcoTJ)ya2775KOs&j25alj==0}}bFyY8C=-itmSPlGFK6mP<;3%EtpG4tjk3y!){ zN5zbRj#mk(bFTg5dyNtnC4au|O_NAQ?zq$NlCs>_>^Jr&phY8#e7yV)b&|JM{OV3V zI^(+J8;z6K;PvV;P-3sgbTX9lEiu*aoH4-{%u&#!jupt5dXqa*BeSbO{>e1It_`}V zgQ02`QA3^=PANYn8i_CXND^_|1|bnCyZnytj|di7Wh=}%q-r?nt?PxDUy{ItlnQ3#a}K>QvZn{9nh#GUC50bpVct zwb(dFZh-ge0^zr8Dy3luNQhw6Ry(m8Y!RSeN)^`76bK^+hs$wios z{Au%eYor?gI&gG^DC=O-xAQRtEdS3NQNy*%WSWrbdIY0Lf0a4egdd zr-_v@yFBhLRNpPo2XLPx7w9=?GL(IC1ME=FN80Ltr8oZs$Ka@LniVmpuHFMYA z!~K*X*I9i{HZIA3RreOJ=6~uD%^#izGLfIXyltawSzpFqY~^|;j>^CuIsv72UX=P0HjB)yuX&Jr69WA&&xKtn3 z*L7UBbHSFLmq@vuWB5jH`6PX#f;y;^cDNBMi!d7R=aGH(1ooG|T7Db}sAq3w1xd>A z-}WoKSV!w-JOZ391_*n2v2@Ym!OS|IV!n^xhP>B853~g8uSM_WI})n^T| zGK={IkX1~txlb>&D{srENj=s?{iP@wGss}CjixmspU7UBsB0@_7t2?-Kiyny>=&l! zaD6sduFOq6Av!q-QluR6Mmh6E-$Lv02Quum7NvUk(4;pqGnuKa2SZFkHyL%@Yuf^T z`6nWuyn|1O7SJKfxKBGDCgW9CT#~i3A)2e+rDw;KFbHeb4QxYX5mTS><_8ty`Cw?` z9EG*4bq#|{3tv5dJiK+^BYN;u+$Bb=gANAdmZN6C7T#|COTk>UwasSAL-rUbk!D9h zft&f1bzEtMLN2+V8QhLpA;1Zed}cR~DwPCH?XS<7NN^g(x|`8W?10;#vI!1-Dd`r^ zs}qp!TFuQ#Ii6n5hP=n{b*DgE1LbMyIcb`njt(4R%f2q&|NbNCIlWXDc~G;w(zpTO zCpM6*3@$0-HD?bTQNigp97pe!0`UZ4=1aVTmN=}{Cck!Xzv3UduIkP1?nPE8Q_`%` zmgDl+yCz%P2W%y9-el#0OlX~vm9eBN;jT7)OAtUUOX60+XH4;93mkysyTK=b#r7%S>J2R zg+8hnefl&ph0CMIs~4bf$WxRra(HEHI{nTeT7TvNqadZY3ygc&=*L9zMK(U5Y-CF3 zZA;oo22{!uWn<66i5+WY7fRqUXWAzyPVqdhe#eH0e7(_AuZmxWm1bjwp(RMIA}sWv zjvBWtkM8uz)+R|?C{BRG zaWZdxt$(2T_u!MKasPw8w~nfETib_~MvzWPB~&`3Tco6< zk!}P8q@`23q#L9=rMtUJy1S$szPa4Ty*=-H-t*u0-*1dP#>Tatwbt{@dC$AA>yD#F zc$>%^+<_Hxy2PdX;CPBXls2r+f{{kIGoqD8N!n`=JQ8+Rw_~G~DODj=rFOrS>7Yel z9v&-`us7Qyk!04}db-NOTq@=Tpl5(cp0?=YEi^}!8%o%k_7tHfIUR{LN)@;MSe=e2 zr?yhTe4^lL6qY4(p-(4)zEh%I=p$>?W`_6uI{`D;-f>iQGxqdZ#_!SbX_Dn%5AL5b zt)DxeO+BxxIOE!F=!|QRY3ROr(%L--l8WUQr3d+hYjEG7Iuj_yQ{Gjy{Mq@E#1DS6yulWMb5Na$=qaL-(|9P zKuv$jod<*r!5*b;dQlC=8*C`58$?0}9eq@rt~nM(YSDOeoS>^GoQhyOU12TD=-Q-= zUU4nlTwkKaD$xS&w(+HyRpnDX1C2Yu{BLS*m{@r#H9B!Jt&R2 zFNSz3eFkE}IJ?y1X+_rs369ksopYMf`+yo#>Y+vi5{Wod@XTk4>Y3~;jn8MtL|gYq zHWNb&FWXgp?6v5#=}Hn%UwC7AZ$hZ=Q}fOG2NZM2}MO z>5`U6C|E?_W+e3*RJe&N39=Pt$ba~}Qa$y!R z&ZvILmK?9EFZGEOggRct{j_JgdUCnTf>T*SnXw3zA5@-_v}?+CChp_XN$T_VpAya!-^Svew>6t5mK!8m_qYHp7=~ivguE8jE8}7? z&!Qb66}5QTv;*LmEBKS$!z6&&03H_-oR$Fh6{iooPu#``aDr^M5smPdzXO&}tRq2B zYw_p_gWyxl@Y67J3q6C3!6_IVF0S+tMNQ{ZI=lFx+F47KOt}_u)b<6r!{sceu*6o_ z!n@*b)(Va3A?s<5T8+is>7;12?^5#|@Kc*)Y+SHH7`f z_mX|iW^;4aArzww>k=bGowICGvOA%EZya+~-W?y0< zCoM^biHKLjpT(BG(JHqTWD;tNoYts=zkelsqk`p!7cc5W^70yW&|s`B_JiZ+y`AXr zEi8*|100xI-lvgb$pwO@fv{||p9G_^=H>F+H-Z*!pCEI3td|L4LGcY_o7} z&xwg7gFGgdYu62p9ut9Wy*b8r1*BR062lMDcIms3zwi1+Aewyt>GKM)(x8eX?3wO# z4MvTkH( zS1eV1Y<1blBsD3E2~|80%Eaj-a=kf>blrvh251|;kB}UGiZ{iRHw1B4V#Q?UY^g4v z4d}RM|D=Dz_!PykqlU1K@=P`PX7xGOltHpV?%@p=oBR~+>r8skMl@my(gXu1H(YY# z8@nkjdkxq71IUS23$u^4NSg0&#m8M~+~1!K9)B!pt!ga0d@J*S_q2e7jfz(&{$r_F z566x{*0(A1L-?{AT&~OWu}Ed!2z>dpmiCh}cIJ1BmWlN5-NMNFT1ln)o}n0g+bZ3> zC=P!GFZh@VsWS=vjffu9QqS>GFZ}_2Y~q9PpG{EALJ9Y*qzc-R^L0Gm|O&j z+{|Hha~4jHeRJHe?nHfOyFI*?xP$3yt~Bg&y|;lsm}62;^Koi5rU6w%L(aKZYKU<1 z^A*JyY`X7>Wn1O-K_lOJ^-59?JqbKMlI3N;x}+CldwL7@CF=_tQ%8W#I(Ghg7gLkU z(&uzCb7y&*)!)D#XlYe2OD;&b=YsYqpju!VkNhKhsFr&tp@i%5G@Flv+7xnIB+^*m zrwYo`rVjy17-rBNePVu=bjQS)#4$qRSFJG|Pp?Oz(nkX__;3;SA zi%e-Qb4!)m;e93aV&;>AV_ctkjF;?VdBuP&xt+ce0YKostXna`$v?HC#1qAae3bYW z;@ukt9A{8lcOA1GMbTfQOByj~CtgPNET8-9rmv>@9PZ%_02-uLHbf}aW@d}r)3VkpIIrl?Xg-gNu7<~Uz zjUR?-iowoLS6EvFC==t52j+vTNw8mXjEg_j%DuBPrSRHHl5B})yzOV44UOA|6%K!9 zb;N|6y4?gSS#VwCC;s)P%J&&lzT=FHhM4qEkV>gmQf4**$Z7g8&b~apQtibNXAjOp) zU?4s8kQ>7F^$)iSC6W9|2Z33l`Lci)-Hw~+2Ggj{j5;u*g++KzEH zFLQg@^#}q6rTZ&-pF){*99_5<@U)_w(WspWeE2r(U1|h9YUg%@>l_^i|6`-=>y7uj zob$LUgGw9Ub>zGaY@L10U&@@RvGRGhane!2Nx&n65WXf*aOUE2&~;zO`;qRe5r4*S z;b@7P?mtc}kJ}1I;j<=OeWuPWO`rHBWn%pdY+bv;yt7=ey!2lwWX72cq7LSA2|3TI z{(RjW{N20WFz{H@rN~n?{&<98=2NwrDR-8&Yq%5v8;MWiy}(4Ae$ZFmg=TNw@8>W;{Dd4{ zXsLnJs9Hj(E{!)e=}TQk3kT7jp&W^O?HeP#gPfgwa*c)tVnc5^FRV!H#>@uPND|@`++uz1cfx~0^6t%GOu!QL&ybwL z2ungUtOK=h?Sx=NT^0j%M~`E@u_aQgJ+Hn=r`YXdH8V!_vd`-C(}a077z9mEi9@m+ zcx1+0iuayN;7`b$M_W#ZpY}aOsWf{h)=KO_oh#~an3SqUvTeg3&E+=GE;21z*k&R! zu^?K@pGh)sA9Zg*K$-E(JI`i~8+1%0FLF9x(%a0qUockS-mKk39Dj98@^>yub{Un9 zVWLzx)i`6z-}sD7s90yZhi`S6=Q)#N6E4nGCEdF?e^$cTM1$I zJq}^m<+RpR;A!w3X3@-_VM;5+P^tA1Iw%aOEEMney?#bnXl}{b6Zzw)+i)L-0cDUY z5zYWQkPffiSH*RejO1|A(CYaIYjW~@O3y0JI32iVR(0 zvZ;gme4AzvfWeQ9M(@S|{k?=?RM`x_R6)ZkTxDGkiyxMnCaV6_bx9Op7vA z#nU3Z|C#Pn2SsIxrodO8TIw?E>&5QHsFX{SC~R( zKC2Df%NIUB>V|tJq+sGTJNq?&8vEYbR+kwOI`yXWDZw#grBUDK^hxJLm2+4%Wwyx4 zrshP?$b=fgo7HW~Ngo73*R^dc2?pGkp7<33d^RU`?Ib!Kq?}DLR?Cex6Yrm+Y%j

F$vaCW$+*?5 z-~~cDB@O!w@h!s~e`bRh>y-HEraC%bTXW3imaSnEB_XV7lUrXt(980Q>O^s&9F{Wl zyrd1eW$hQCzqhWDzvge0p+U!6ek(RrN-DX8+Y;d_x72!`z_gMuVm3`TX<9XbhspVh zp-l&;e?@xgJlAB|B1MfuHn4#G<2V98PHZN7>o8gJFTpc(10l!e#=q*ZB@g8oYF;cq)WtyD2d6 zF{9NnXnbpJFmN*CBuhhI!{pKlI=KSRsnoQn#u=IT$!0-5l(5I4H)3=tv{Wl_mE4=q z0ewV9nR%x@Aw{mAv6E1Jc&+exR^i#FIBtozTNUp1VPQ|c3@hLmpB%A~a4rpUmB-_fnlcWS@fKPNgxR=ehZ z8_Vlm{)t95bBy&mK~VN&c#)X005LYo*Gt<@hL+-&|FbX3A=lVGo3 zsu7%m?v2$?&ndtPpfg9mz%_FU@M8~kLVJw*FyAB7x3w0_A|?cT2f5Q4-83aXDS#UGvg;V2729V8xOx?b+Fd;GMABt!Cu+OhAgPSwH4Sd5K|uZY}4+ zbV9WEl=n`AJ{n)P$0D4h1c!n8t~u6$X#q>OW4OIN(`4q)cqsZW1bU2`lsHg+3M6f> zrAzLsh~-X`SITW#rqh{eyYbk*Bw~?rhToJS))YT>)Si?_sx{PzDZOktfx|JFD=i*~ zM8<^H$KemZlXxz!9Z(@P%UXp6VFDCLlH$8PX`wu6i4ozYvG+x+lyB&t{;)gV$ac7Z zgsY2#MLDkgLt&+beCUyh*B6%MmcrLKl-2S-7FUjKH^2Evbp@zzh_6r4@Zol>QU%(- z!^2Ex2Teyvx-FEqvIa9ww#ym3p8(kMi^V3At}zybN=_S2^HFWCtg_GH!7B9^T;;4T z90&C~)7HJ6zL@MUjL=-JIX`V?16LMuRxIZO8P#Jf!du!7obM#3jI|w{!!8+f$_vvi zZ6u%AroK}H#YA3>9)ZLuF(IYdi#dR7tXZz;7uF8!Om{8ASzn>m4R@%JyDa={%Cy-g zw!nUOv4W(lXuJKDlMUTG`JPE}%`p7?rhjFQCoKZk<%>>9g4#zY_EC*_cZ^PXL@P-o?VPOH(nOVvP_U0GLmRLy(V5u{8St<>bWmzh`)`OH=X_fdhsFmo(8+~mbibSV4a$6;|+WMJG zDSJf~qYOlsz9*+%2fKq-ZJ%SN&3`OO(D~Yly!|+hk)p)-eOY*HhX>C8F)wnl#!_wr zGik?|b?C=!SG=k=3E^;Aan>fasVRi4hDiKp+V7+;^HvNdj&(lVUW1@EUV;nzxrC=n zcRG?u7G{>Ie_MF>`Zgzu=Be4Ud06CcG1;TJpKKc1_dH&~si(&cC$r}{*SCHr>iGQ5 zEbS>7NLhSjx4cc4QU}})H{MQ_r@f!GWCNxShpoA7DdrTf$-c2(>AZM=oa1KG*qFna z+nO(#VVD*F)TDcX5-D2`20K`LDY{3Ah%G##t7Bf6XuV^!TZ?I;`T5#&mcdi5PL_7E z!hFglknuOuJ5D}XRx_7#{1GyO6Wk0E=Tm-7#g>5-V8LVm1{LA2he^uzTBI{eAY+#Eebi9 zj-&8Vg`^s~l)C5`T zw=l`_4&J77h&UI^^3hG+@^fVE;VkMWyil7%`s$o)rG z*+fl0w~Oy2bc&ejU3kw_Ojv`g#!TdDGqz4=!=($^lWnJsx@pQncveFNDV=odL~zNR z@=CX?MQ%&}T<_dW9h8mr8DojA2pRieU z9LraQETzJ-vY*@}8C@+v=bxR>%n5mB&V8s>V^vJ4t}85O8>UNjOI*n80jRc zrL4HVAFNE6pL=)=>Wl&Pg7d)2;$=mE8{WF9khyEeQc0D3{xsph?7bU zt^al#)=7mZuv5uC$Jscr`4ujH ztTXr_9hMd1O=$)5`hFnwzvAi@w zhOX5>3Pd>Wg|IvJN^Up>Ih)K_kl z+gJ6?1;GaYq0;C^{PDB}8`l$}`=@G`lPIeGpr5d8h40kuD$^1#Ne1ahfB|#JXc9w` zyZmE9b#he3?+=wGIO{9VOyV^MgM>!B!wW!D?5s9RACB4`+xmK^TQfGR#e2il%mLQn zfwlsoSvvwRq9~gsN-`t8KpgW|(EIeR6YrL@g|9R^)<9QWLnmx2@uH3cT!Q3-`4oTk z#6wy4y!=vsww};3W1mvx(ZiL%oYluZDdhM1S0B0H7ymsI_*HkyXtgQtrTUx-eI(`mXLFb#`vIx^#b1+z78L-W6$xZARzt00W(;rS&$WexD(pJs%)&aB_~(;HTF& zRO)oT7a3+e^@^Y!A>;^%SFrz}OT*99A9FtM_Q*T*wTls(vez5`$Py%IYDPxe;rmRe zrwRfDYv!dN)^rP7K47{A;^TQV|HOGTnj@F~mC>QR8IIs{j%s?u&bywJ;o@yC zC#HMt_E&+E6)%p2nKw83zI;~*#L${cYLL1s8BvPv$U)hO$SJTli?C8lv|PUl?EO3| z%md2L%>-L8$jBqOB}w87Z|u2tsC?~%8+^6t)@QMm)P`$!G;0y? zMs6G4>wyzS2gGL`Z#q7Yrmb~~O}tZ1H+`Ew1B#MHjnDmBOXTQ?1LlsUPpRQXJjh-TjB6!I?Oqcaa92dBR63-hVx~?6`~_49gnCFGu=hl2 zFg*rF2ae2r0`!L^X|HLole)?4Nr+knG~=bE+C3ej13$I5ocH6)uZaLDt}q6G#G~?j z-cj1ONuKBP)&gR%Lu>^?%-2{t^)1n{_$SRO`Z(bwI1)zyj^T!e#{(sr{mjUlku+Ud zrnSi?*D%Sx!{RYE>@9?ty>TCNBbfTV??iwXKR>C#$4^AU7()+PaHNFd@GqSzq3V~8#TvY*o+T>4{F9jKx^y%?VGj6$oMsL&NeEhl=@Oed+?y$$JyMZ_VS4#qA|!wPmd}Gnz9FDCcGz2G zIr{{v`Ll%R13>{i7qD{>V9OY;NI<15<}Uz=`yE=-koZR|VAl50QiB9d-&Z zANdLh1Zs~{h%rqFy}kRjr)0tpD0(G7fHc3A2(f+o3}}9jg9%^2cWZ($`e)!9^2;c^ zdCHK1{!TMtWIV0!FKkQG{Zj^ztF8gKNukeu3@`xomgs(t+;3aVjI4$I3*^HPi=fqP z62vm`2^9Ok)Pio`x{OW0>eBa%%X+QWw~)SS(ZehwQqdAZFh27kEl5qSV^uFo+UP$Z zPnu2yv~Wfr%Sj|JXOsUIqzIr(fAVlkkC0{gt@M9Eok%c4usR0Vxy%0z3L`@XD95wW zl^p)d4{F6hc*$Soy~7syFTV$FiVk|JQxt zu|PmPPb@N?1Gu;@435TUZ3sMv{2quk>p5t}Z=tQ-Ej0s`6bn#?^aFLO4TsEH@^f{f(>pTj?AdedUmTz}E5?UIa4N0zz$M*MVi^X~W z8Pc;|z$RFCGXN4NX6~>!p9nq{yMuGNG7Fh;;iQ*A5AqJaT%e0!7DAN&YcfT2-719{ zOJL4g(nB~Uu6GUjxR#Eu=g)cVKs%k( zCnS0ZrS2?P)4c{Tn4=0p*56!xgxj$NKtr)7qgdf+-4GNEF{r-(%+>q@LDr`XQ&_$J z#BN9czB=BlCI<(O@MBrh49KX>0)qBsshF{v(qqQMpSeo)Jw=GxoQmiY!V_6Ge%#F; zX91OYn1Hfp3pY0ilhuQ@pTK)9T>-EWi!A^|f=46uYx*dr!~u*9DNN=ID2VIDhlnyI zXnY8!0a!trc4}9Gx$2k$t07}O=72W3vu3^MzqQ2Zv+o6yIv(2rSl-p>Y=Y)2An|!W z*tQxwx}_l@cvJi$#;AU z<9v*i3Jn}$xW7?*kY%SNYpOZuA`7whA++?>godLDWp&z7UD!5{C!IS7P1IE&joQ>4 zcmsB39n#)o2l^JS7ciyRpTfld1VXh_11WN^bZ{O5s7#q=GzRKHTP*;H)r!;CkBMX2 zaw~B|=+(`l)oo)Jq1PAoDi`+Y8gDFJp}osjK!g5%GT+hH5Q(3?NPR;PwxAVW08gz2 zS@X{NeVd}=9nIwxhB%g2mF6imJHt}x4v>2mhG?t6LKkcegg_D~Q6#W)aovGr^;G~Z zcJ&2&QmE?e2*98?wkmfo&MF({)B!qrA5d}EUW09!KW>Ymo}yuT+M)qEUDq~ilK75e z2d(h})k1i|?19^HvzYjoke=+>$%Zo)yW7*rEMB{61?`cw_b7j8g_tD-!cR>PHJ0g{ znjBEJ8X^@{Ub0JG+=6{ygV0e{NxK|{-$ivp5J!j$z4U-qHf{#cTlY#O8ltp8+8)j~ z(?U3x-JT*ahY0Ryw3TnF)wPv?*vcBu6U;N|uDr?ZK3|RESz%wfz`8FW*C@IJbIXlX=7?7-g14$_S9nBZ+7dE zMNS##iasy#3+2hMF9210meR=`r^HWdfc0?{{t(>k+X#}q9Yhvw=;g72Pt9(BN$YQQ zDNojP6-wRiuFRF{$k3_PMsoECpu(7ar1sN-IjMSMkSx|+h8u4|DYN1D5`qe=ROUo+ z=BQv6nu+$|L;yqGkvpnWlVj`Rqsb^M2#>NIIWFUghn^1hQzHDr_{9S-e84sG z9C0BR?<*$vw$PvUDdjm7A@2de*Acnm@p1raV$Qzm0iQl-6XtXnfR7pq44EHDf3Lx3 zIcJaOMR&wjza^QN4Mtuh+a%dn0omp{PzkY38S(@tSrB4IoEiY5@@q5KYvu7?_ahp#4| z33EQ@xj0CF#K(hnRLGm35SRS$&w#Mt`}Oq@IA+4qs~( zOn7Sg{LGE6gX!D%S`-Us{V&*xKQ$TwRp4hH4h!}OFWog%f9ZgF-DQ#q==TzEuNf-= zu$+jlGVI1=K7A=E05~1IrGgoe$=w#_h`EKfp#SB0fr7u<=w_9tRkivk+ z#d`+q9tygI5CsGfOy-u9$ggaw0luTvWDs(u=dUk8GIImqApPm?@pKR;i#!+Tq`yKd zl`Nc9RNAQ7&6u*8x7WPY{@htc{9CMk-I2G)12!3e0P>bn!*h&9h5$JaK;Lz&)Ci$9 zs%H{Vcn~ORIwm$JcCUQJXXhWhe`viP()tP8-3DP>Z>h}14Xi58UR%9_Xs#KM7bV09 zrZJOYYPdP-b$r>Q<&;TPIL1aJvj8-5jHS z&hRK2c=3kdk-2vym-$e$%VrQx$ z@Y;Fm0RE9BdNg|j0dVr8)mhokZ+KP$nH{bWiM7o9yYYsGivE67FfmI@U5Rqyjpy7R?nC?=8$+B(BuGfLq3$akJy)3HFrcI=M zQM}h7(|J&9tIMHF8g9Ylnw5KxjW4(wKMDZi(F`P(ZSu8+APehG2^_i~FtnH0K+%t& zZ??bF{Z=t$5MM1P#nE2B&QfHQpczxO+XTK}9^p#)R;63L!FO9CJIm@*B<3u93dm}_VPDVS)pm#QE z#oM-TJnEq)a#esi-d_!uPSSc@)5Ev$Zbq(C3(03<0k|e%!=T&Ccpqr`@mK zz@_>rM5h{yqLiVn)-fr)z1(W^!Fq8PGEk;+1xV7y*}S89<%Eh!b?Jz?RNbZU`<_R2 zR$HH_q+&Q-MaBI8I7kemXuhasJNYHkO&k&k$j z()Ui#o$bn|5U^gLtb~}1CXDccWU2F-D+q~f0J&8VYVRqg^~Mb#AF<@S-IA^+sKM)< zNWFgB24$fSYb1o1x{6Qjsiq_#qDwl$cQPc{ng>~-f6mj@8`249BF*G|=IKJM{J9?Z zMzzRaGkzd$!el4}pIT20sz(Qo>e_jP%3Q>t`z?{LP!Ey=4W40z)FWpBGHjXmvRk?a zf`wKXmbhSDa5}Drbj9xjb9<(g)9hul1>D}zLr9MDy3UU9Fa4GcVbOc^=>(8gxaR~t zP!bNMd9#6xR1VHqt6@w#T8Bc_(JZPq2dA2dKpdkeBZiWC-iYLA`=e6`O1l8cl7(Uj zxA8*AXI|`272(Byk3hPm6g4KeY>xFN4llNueF-uVZ88QAf7Up!;0ckx)hHbFmnC!H z*;#U#o)N3Sj~_h@xk~{tN9A!L z{J$N7b%Z8efGbWES*nyY)Cs zoQ~ru(9b+9OX~EH1OK;G>PP=P=r})$XbooCy$U1Wg(3X8{k$)ii@XiOJKUq8ZGZi0$OwiRD6VY_#KIV=je#hm-+s&^M0!MyA4dOp+U=tO;rI|x^fC5Ti_ay{jqKR`@T<7z8 z4`|x{`$IxiihO94E39~pP7YvGxn#=F)~+i4-54{#EUVH({Nu{+YSF{Kd?>}JU@)CK zQ|Y9JqYX-#%%N2`MKm&&W@6Lk`+QZYWe`9d&w{vaB~3VS(4V?c`8ob2h<>%aA#O3U z`2(s32n~xCIx5oW5D*wtbH<`%OAxgM?*ra}{QDgxLlg<&qhmm6-QCQ+dSw&L&SjLM?i0Qs#<2B*bmWL|$dp;5`c<_sX?Wiq5o_7L76!Lvmp2s=TT z9`XkWQ;Itkj6vR2He@9c+%1G27-20u81ZE3XpZ zbk{xsVEV7Ah;^p z=f6PhXlQs&*2s*1xx5Tt(3HUn&F1D6ZkyF?QJ#YX0W?>;;Ky_5m>78n7ue z5`_gp$eiwb*MOsLV8H+d30taw&S~c`Cpu>yU{5<%+j!XhmJxr4l!EW?lhtJX#;qI& zz8W&^PVo6Mgp5P84JV&k%7~v{LJmu3K*2Icsx)Z65_n8 zYX%~6Vx#w3OqN%-z$59t2Pp$SYtU+$1&L@@=Cs^eMD@N^kI9N(w(sMpeq4eK84G%q z{8r)*oxuFq9pYacc3$m*=3qMjxUGkb=On@iRP##$8HW0AcY8 z7l=d8TmX*6CWF){ocgC`*t;`i{q#CQht_w_Al4z04a^q-d>?d`U?N(rg5vvjaRs`{ zSxZwrAqz+rI0!waJ))77EsG`Sq0I1a`Ts{r{k^VW+@m)?SLu)bBhrOGUi`81bqf&cea5O3lCt<^6f`~OQAB#5EP`j`S% z4#3C=2?MC}#phcB!T%2c;}N4CCkKmmx)n$HRUo(*)g2nx)oz z={kPooj5h6tI@x&ws7b4!+x}AN8z;6aUP51wubbPx$k>P|5WL;t$`-`|6L$w=G-fsXIQ8l3R|_Pu{!=wAAqKeS{QKkXqD zny4TYHnpO4{d9@0$k@E;^!Jy>JdwZe1biY^@&O-HS!eRRAZa2Znx4eQaE3XW8-FZ1 zJ7^7Fqxc;oe8-MiM)k)__?LH=C=ewwg8Q)qOMc@bBf5^?cx> zXB>#B`1%1<38of>?rvWwWVc8osT9J1u~UV6{DgdJYD)IoYuFr_H1CtGQIVsgqgBvl z?1Ay;EBF)xlbzx*W@|MN2bI5g6( z7)od{w*MhNzb_1Z0R@APDMVcV$9Detrzj?Pe^ifj3+~T5{Nu|ogn*o-cBSF|A985) z8oVEeOg{H7!TtTizr~q96#+;Lmi-A>wLgFO_vK&-!299u=>-4jUGN!EEN~M1o`q=q zDIb6QN;;yu=SVt<(M-_p?#HP*`?6Obe=wL{i+29&A=ZSn+WoO*^Au^{S6k}a=e$9M z!4L-Wti36N{A+KGL?7^#t_o0eDKn*cRSZ2?2`1wD$htd6pysmBuvkiE6j}MB!BASq zwDM%`naNJMIwy6n?c=;UNyqKCN(^)o(&TnKPLuDM(-R32b74%?Xi`B(1tm2AH(O$f<;2eMFo`j}Cwfg<)l(=EbG@;tU*`J;vAjb9a$C zu9DVnW0s?Z%8X}sI1$-u)r+-1LTELkq8J0_Mp`i7@y)sw@^Bh7Xjl~Yah`WR%X9wz zNzPt;p`b$bpv3R{9Op*FVC4%Tn}aXVm9gY>7`FfAjo5cHE605vYr(2fm~K=LLr_Bb;4TW^sIEG|xkH1`Y!Fzf@b<)6Wqlx#I+4pRs>WvL?dkS-hJJ4p zlhehgDY-tf{FTm-Xjao{(f&A=#NgRkwOCeDmBSMbq`%KRX=nt@PqJ`JWrj!pl3sUm z%Fqdc4zX~`Zk9Y@?gT?$| zT=>BQ{s&@j1muk-3lsa2cxGUqU!Nb)bJf?Qz#-$2);pZ~b4EbJqeSpJUv+Jd=Os4G z*E{xtsEi=m5WBe*%YG zh1nbxuDbPC`4`0{n{g~Av$?GRET~8+owN%TlFQHhbIZR|GM=?Nfx{}@W@kdG$RKrh z=$mI>JexG7L<}0v$1#Nv)0v9Mjv#_w4_KshsYFgGgSDQ>u<-De6Z-e1Z%WI`)SdHP zPoAiV-e6F36!G|JLh;Ygzg~`B+@%SNdK_3F#!f@PXObMCjITjOi1S`SYbo<6szKFM zdewVQai+j(fYXfA_Mw0B1qD4lz1yp2v2REVZbUH}+jwmX{8aOgCLFDId3Z^`iq663 zBiPC>-a9qZX@x#J40xC%Fvh7_aObNqKV)Oh5+O9b4y}4^dYi9uLa-i~d;!Z>x|tI_ zL5w2 z5_2_5_k)|mX>&`9&7Oh}g$ssKQVTu$1_z@{bUUy5?f=llq7R`izP@A-jY0UA$neXe zi|&x>TBK9&Nprcpb4RGU(<#W%?F|0%<`ms{E?Gv2&UkT%zP;Oed#sz~c;k7!enZ&%e}t$}R#S)`9rW-4{T)n90M%J|x`x1XQyE|x~y~r?OW&X zrKyK={;kVHbHMo$Kd>{{Cj7SW`ySF!+?xc1)tezBTf4yobhgS==~}0*r#_;`>-`zP zIF~wK+rFU>wdnLu)W~0 zEY*?lmx((QG9^?Qd(;_&oW=P2?4XH01N)dkf_&38QkF{b`(-~0v=7o9faP%p7PB)D zpA58(R|FEUNk!4BQtnI^MGx}btMT6a6vR~?L+-7amOdv3lezCV!*Xxd*&4rsIaKPr z02pm~aG5QWwhC(1VPM3(hCzSan#VWTpr1-qSJYTYNu#>`yany%8;v;Q^F`+~9NPc* zG49B~gStG%WNvWXybky(*8<)bO4{K0eSc`8tWfSAfv#ZW_}a8dQ6!iGC=9ZAu3hG? z*Q!RdRTg6nMx_xzuK{|QyJw(ds9YkJNyv6@n!>#9PSWS{3)O?3u%ruk!jdIN!>eC zE+*>yy>I-aazJ*LMGSh>vXcLL(OrTJ?M2tcTC`?Qv34uDN@@7=gY6Y!pwn#f2w6nm zYpgB#kVHL%hO6ow*TP4JhB3_XK^g2fL3QvsM`K2Rg};|5tioMV&4tlymX z2KDPgLNJL#M&(v*JWT}EP6Km_My-lr%&aq*$PxfPtIi?a>)9HpSEoB0^L!u^U8>#c zJ;dQGuSPOB;kF9sO|f$~{`h$&aB88`LP+f0U`2kfMLMs0lb@#Sw;hO$0mW=_D#iTG zv*59G`mYpRT&hUd?NJo;M@-1xWAnnENJOIl=qAFVJ?{#nmJaDoy) zbTr-E++_27oGeryyDH5dl#lLwW`w-Nm47O*Y7zc(Bs-m23^v`=cIVQ;Fs&7U&J72#jQiM}IKR@%+a<^(A#9 z@gu=!r1<>XjPgH1@0#9-{QTp3Le;$588w>bz3lGVt3{EQJP)1wRK+#iJ1ap+zcB;% z#bdQPh1g7;;R649X9j3o`llf|%BJUGu&{I!gMO#Fdwa=GsyQGqQna>t`R%CRw=1r` zP8X-u+NatDBL2J{#4DlX-uqA9dN?xRf(8>MbbCWJ&Hievs{e1MdA0iv? zrajf?9?SUsG_@oD-l7=wHbQ?0~Mc5K+1Ch;sZQG;vjL97!*c zdSgs8+cVg+8HZu-T4`Cv9hCNp#wGt$j6A@se6D4Hb6wHCobN-)6^tU^NXuLe#&FNv zk*kz=qfLGR#3NV_JSE1oPRa^ZQ&VIb03HZk!-Wp6*2#5)U87uzBOYXC=SZp$QZ1XrMkKU?%bpLNl&1&y1{C8hMkekVjc=G8t7iz$xAlh`&N1g*510W4ZQrxU)9kYBEo$=-D0#pG)0$ zqfuewiTo|=)%Eor@DKlwgNVP+Jdh#u8&U8dFXHb7Nr!f?B6r#PQIa+bjA5zEtwWA# zW#O##sN@ka8Qoyy(laodzK%w$yG1)%QeBrHh5Hf<~2apaE;&yt)!S2W5lB zkF|iPRv&1tRMS`4%EbwQ-FXB&g@?la1u!2k7W|7z`L3h5||mT}Cgbl++7!&DbG zmGrOzy)Gh=Y5zE|dYr#hw!}U1kn3M}&tHSqbyUiCP-bfYtb8f(OQL5p6bn@LEx|y` z8c6Zm?2E3gZ3QwO1Ak%tnMz)6%&Or)!q%C0-DYY)-Kr2gj_G#KTXmZY&@<4p+8w?& zB2I5JIUX4Zafu+*$xLRmRkB-gTy}*s=A{-Z9dxy}iJUgCCwscPEv|r4de%cKI&AXc zeH8<}Z&i5rbS#OoVudLZU0&Ek;Cs~ym`U^TdR>uEb5us%DhgT+a7@p>9!>Op){91Y zwffq2OKUb?HfpeXJ6KtUq>LkKu!e#<`H694NOc)34Dzhi0N3DDQZ(+3W4RjzVTU`& zC8)kM;bSsiYFQ-HKlH)a6H)nC*+pfVM`C>Ig220cmS5yIMpW{d+(3@Tg=Q*4w}S}x zp{MlXNtUIp5jB6Qc&7wImLU917e+RbZ(Zb2anysFxAM({G{q0^N?h;n z!jhBef!F%j!G85`#l)`y)Dl*6h{VoIsf~ZViZ9rNrFR4BpR0hKD$)~`_XjK9&sI zxyi0~_J*3TvZTT%HXA)C*w`jbdkU^Ouk?2SOT6(oRrIdVxpN?mODX~1L#fqEkAvIJ zRHTWml6yP{J)VQ7f2S>SF~GA{W}vv}u-UMO_zh0mS=I*{gqCbfD zWeNHZ9aJw&!vX&^dt7094x}VLn>GuHJHcF520bMkb3V8^5!rL2>0)iq>vNE>s@2Nn zRV>J3Whvf_22ohGV81d^!Nn;7n^~Z3+#e3|SD;{ofR3@;9yR*cNm0Q>NVv!Vewi6+ zOIOgSYPs>6hZXSFMBjdbRGE5R3aDu}`eIi~wWIa23!S(@lkpqRoK_!HyCbT!5Cygm zLru3pj~+ZyHvley?9_}z`P)dUQnJPM;k!p77>!shs+M1DTKC*^ahUo_5lbjyJ)eY% z+V^@-8FTq*M!eqr(nepx9N^=%W>t7qKya;|KmNx=M& ztambm{7~G@&+HVsz$N?i`YC_3;IpWMs39RJVr84~3ktl?J9ae+yfNc^;`mNbb`LQc zVbaV&6t_leHNapfdLw&ejaU~Z1O2gjo~-Sw(2in}2GRROtOJptCu>k%5Y5#W$C4-i zP?$yS+_VnpCv_5o!6c&XfJWKjdkRi#oqKo9gUWz^6D~ii{O@-^!a%{} z+c`822w6FNms0vkp*oE13#kqzd}=|YYs$vY)O`t|88^}RX|eT$#kfi#P_^Qt-xFo^ zijT43MnK8a;-S<-ezF7dA)Xa7!`LDFKzz;220aO3^X@A6}y zBP6J+T>k`By{?`}>eiZLP$sWEcfLM*oHxn#ex4X#wI_ zL_6=Z0!O+e%bu5JB<&uokkkQ~8_K9OUTb&Yk(I95=yC+QV7j1Dxk*8S)aqK^7yk*xI^)h%Qw?nD$9=9q3`5wP@x8tyOu*1gsMB znIa*?sUZ*QL9kZ}l0f*t8+Mea)w)&=7P~2>z)@Q=Eg&#cWxHYh_#%F&jrJ{L(*_sJ z!gXEjBR@@#C9*?t1~bK%y&Gyas%&SKPKvYy26=mm_VvO;I*U@v(J4bK)IPwOH?+U_ zO76nZR68c3$Tzg@qtZz7c5e8^W7-8;89H3KNCh@OO`R4ZD$*n5vi;btB>f4AueBVG)#e}HyQuEytv#@h`pq7& zsUe&@^I>2Zo7c(i)F|wuLEh`#57CfN_HAzzt^R7)*CR*)IJxps>iY_KkZ)G6vG$`9 z{5xVr>yigz0WVCbhb#gdy@LtW7X~4kPcE^dfh7@>#9qil~)9GVB80?E*neKe_bzY%2hESjk=g27?3HU60f94By_^Z|141RbQr zWuAw#Bjq#HxRK^C3U&f((6-e9WEzJz(^qVhXfxiDX=caXtAMS)?uqKKsP6obcQ6cmLoai8Fkaeh9A4~pL2E2T*?5j3+s`L6&GkY1hZ zPlOf)Z8I)Q2K(pIR0nf<#;*QQm1kShT`B~g0Cfa+Td9M^IxLn%ldeoo>ru%{vcs2k z25P_pE$w?Oi~T1nhoL$k_~PeFLXbbxkUn>TC;STU6OaK3#vQ5NGvxe^+59kmS@+v@ z)$eW(?iT-k!($P~a3`|v{hc?cmMGq~5<3&+tgZ>bP$%As5HZM{z`Zt=S_=S%6wG~M z9xt-)C;7hu79w$UbhPNmcrd**9UotBN9~40RxJG&>13j!nE^@_g}*)Li;%)4b+c8k z#`kCk3(JL*HESN#J|FDT62t@>+|xP8Mz#%uvUSLE9nYDOtzKF!el3HZ@V_k=QUfb6 ze?E-YB$^8bxlz{b&WAFiQ5+MX#XIj1(Ae1M%So0Zg$0lW-6}2XPGATAlB?zYoQc%} zFV0W_7nxz5ZTB3*LU$p@dO=K5SVU481&qNi4-7Yt0mxsM>ilUHojbr(NhR`*G>ytA zr~k%a`l9EfSyJ?pI{U+X0g+vU&jBn>ZRcpogqXiRh{o>Kft8W8_T-yr613u9nc^+40rX)$QUCxmcbnDJ&3neJ?-HpVw6BsYv?~UdD35jL z=;r%i94t!$kt0HcLYz8Mzex5bx3lyFmlhhKsZ~!dEosXNSgreo9DbtoixT8J-aRO% zO5=W*BfO`GJTFpZ&~)vbWY>i1u*``So=9L?1roD48VL=uxot z6Pd-k2$Lvp5&H(2k+5sAu&U)fPye|;^ASkihHV#+U|uT@^1Z7e#KAhriMO!A#gYo1 zkmGT?aiOBvWK?cTa^KN_1oW#!fXo8=#lRE<$E^h%7q?^m_;DxcItnOX zz9g){V~ydZRWWnRMSt4YMEC6t_itZKzP-QGzd)+(dM0(`zfcq)VKQ%bmni=&wiW~l zGiZ(swU)VBw+6|sPi8theBgzA1^>>kkmSN)+&c5M)knM1YDC=OM(0Xv=42OfK97PN z^LaO3t-a;(NSFzBl+OV&iR(TYd+R&oEfB3Eyd_4?KB3$uo!5Q2za|IpPoJy-MXeaD z3Wf}|d?rPUt@FGkG+beXCYvgmchvWH$3Uq+i}1ctg%0!H&XveW5dN7xQ8|1q45@;) z^8=b-x?@Ga-YxK3oF*H~Qdm;H^i+06=)K=gR(_#$ShCjzBN?)%PvG zULo82HmngGvIubFT+fA~r`6BFc|&Rq@NuaC31~ef98!KZzbWkk;B#)ETwN($A5(mz=R$xt`TYZCnDdb^ziXAF#wRawp|+d&s(W|X z77C9iov4n&&Rv?q|CaVCEK?OM=9ne%&`Xdn&p zjUz)CXcN^%xtO%!Y?(L;N77Yaw7xFVRgL*^wODRa8NbVc=1|Yyus)c6c)nWY-gls! z^nML}SOjOumt&`=+fznlS0`F;Ar`BM>sE1w*v6GJ-8r{F@Q1gmw5gsAQ{Hf@db#^pEuDvN1TGX4V~ zMlmW0CGL`dX2=_$!^A#yLjCkPQ4b>&2nX2Us|f}~`!>ALV_$FK+mXEoEumX$)i}^S zlQUhl@3$ABA%2cc0e~eFfI7iY)Ee4T?rRF=?Sv|81=WeN?X{H%ZoyazRmEw2EL?L} z!tdNgr0@4wf33Ye)8<6wb~e{kvdX2M&vHS@B2tQG4E9!8mk~oItt)(%DFl~xJcRpq zI*N}o>F}fLT+~m#mKtkduzHoP*CgY!imP7JUuqEHBUX4jhtU)!oK5qH+)iHI*;28A zHVew%qAg^P=g!~J8t)$g(Lg#EsS_rGD*p-|2*_Kz1^wsd?voUirSb7M3XN|Ct~pJw z`ihmpKvhE*@$kKx>%H>8julaBe9SXLFB9iURJl8TmCraEhvGkCrqfxbWi3!22UZ&E zVwGp&U}}r8i`UMV=TxeT8aU@sOY9k&c>ZKuHOHm}ch+?0H?MBtE+a-5kqSU%6_&IU zr9xd>IX$*PKiGG&f!F3|9S}2&FsPpGmylbPMBtZEL`o>~S@Kq1d8SMhkHuTuuxo&6 zA)QIjW`L4JgCCWUmbm}<%P^i7I|O75DVZ8S%7qOt=#H-Q-N06Y(Y&p3b&2|6g|xD< z8l!R}b;_&0WgWL~e|%h3SBYooImd6mq<@JXdh;kTW8FEYOoQ{OL5>P zXOLjt*sgDnL0Xyj8@s0T0>+75zqKIxyW}>3Iqy#P=UR58Dw-eD&NuEz&?5HBWhZanr{83oz5`oSgnsV7IyN4OWr@~Q~{PB!QawZN6RKWK3nt&s--7q`IG-t zVx>8Nh!?|jzP_|EJgQxQUJF_nryvyyrq<2&e4*la=C@e-AUg*$3-yC-cj`KoBl^4- z)M}b((loESvh&?ka>CXBQ*bwU3mepnR5vo5p)&7X^oY&NA16~`hXO3YZujGtUoA3K z8yE46A1&k7Zu}_hb7k3G$7@J?7p2$4vaLjsKB5nm`zJ~NvGwD;1ubOE4vj_XX80e| z-T~UniZ*^6T3TJ4=NHjqzW$}OfpPJ-PKTLZ(soh;F}#(W=VcqmJ7`SGE7j+TNSvk} ze3y63Bk;o9)t5HVO7=T;US(=_jBPP!*Upeou>l$iR1}RK%ZPqZzV&@4Un-+9^XRM> zWy&gMg>~x15-r*T&(2&i-y*SR9oXB${@ZiIpOQ&}%mf?nhUOj;b5QgC{ERR^2$TF6 z?_?f3=nxM763T2BxNXPO4&FET-Ond%c3A7@y5ozzgHtAuX`3TxZifP@VO%V*d0!Y zvx&_mC-7qdrBe_rrS{&+{8SByPGS4H2JJ_~qobHe+sJ9umLPeC?7HQ~Z(C+V6L$8Pg_HNEDg z&^i{APMKg5^Y`^XabR8xE!FEIs=reVJD_o?7))LDCk@HCI?T#z)yC6%2=Y{Fn?JU^ z;~S5%5oJ;eDJv^`Cm;dC%-RlxJE}^gaA1VR{avJKP~m`fh8zC#z89)`|j&%y;g$3yai6O50j$XYOkIHQiw?+^R#7*@Tn*ebLum_H6Y+bGGc{v~yWt?z{3H)Gj5sU(4Z5f4Drh@DqgV!BApUa4^B-n*U>U2ab(z~9n_&;h3>S2LMk2~luo{yd1YgrSpJ&V#&?Go|JayoW| zU^15k!Jo^?SYnyiB-Lz?%^{^B~$G}v)-`m&+M(Iplf}b)+v#^-B#vg<8aOl zzq((`XBqUK&*z%EljUf;GR0q#x4jd=Y~7K-eaEI=(3J06P#d51N_qcN1Ca;0K+EN{ z^UdM!9}3?%11N?ZrdcFGVo_YZ_STSw)ojCbUzyYIs8p|Ob`xJebWHvn2@+H0H-5LY zN8ii-|I_!vRWMnceeu&}SYq+$9yLeZf$(_AI0`#p{cQli^<4h$6kT&PpK^)qB;m5F zs-p3b4aj-C*>=It(Fp&Up8_2i<&8n0^b2WhhCjbB_qe_L8RqvvNhZWkbT+%LGs`>RR-!|L{yHDCfDd0|6(>mN@k zG=A!fEx?8Z{P=KD#qp}?q@Y{1EBEbUtOwtDd^Cf(T&+d%Hu zr9a*hy$D`~(j()~&(H#Nz$f#RXYaj+Dkuct>*s*c+`YOd>2p=djJ>$Af*(E&TQdL_-9mtLBQ^?pEKz;}w+^Tt^jy zWwF1~kZjWy(|TJ>l`5t40qPaQodk=)3TudskOFaKe54>iL1etv39dWq z6#3P=T*;x0CG7{~pWoi6EiE!fk`U*tQjaZJH+& zhFbUQQFa&Jt-3X?ngE%1R7TK!y?;EmW86C3sJ4A^_lAg8M#=qkmyTck6Way?CuY~` zRa|+~W=YJYs)i-tR>w_f=7~Y;SrtvKTTiuF)OW`=G$vCXp&$^)_MCQos*#*AMoD-3 z;zgO!=1+jrx~6NBCq6x&7(3P|%N{N{Cs#srKh&;IN1kN7=EbP_dK43OqwC!S^JxM# zArt6NK?n;v(VG6$UgwzXI`x_+mj$h*gF zfq(-w`&-J{mcBqM%BqHJ5e|-)zMAX^elfluh)5qc9~6LJaYx16urwMGeD>BeJWh)N z{)j6Ny!^W(GSPaSWC z@OZcrEX3&9&&&Fy8!uyIzSjvVT;jV>FyBv6Hcg`1q4uFj@;iE}n}!0vmvx2({Xt^$ zQG$20U4^=>2|kmu;MV|^fHEpOnd&%oV)ap2)7Am^d&}pOs;Z4k?OL#|^$}f#?P;3D zA)u+TF<0f0=~JS*OG0hdF=%%lRom(F`4g~$^82j=3hGkLVZaM2tgHECkIUz4%5Tc! z$`kqX4QvZ$p{sBCuMz_nvebXWfMxZ%w|?;#)#3>3W}<=#9zQa7R6VAFpJ+%rHJQ3A zXrkGwx}l^O!)Sh!LG~@bYG`V@b^~Jg24*QHM>=Yg@@;3j zTcLR@FzITU{gTW~%56CjK8)Iu$l1PDA(N-9nR9M-aJbqxo;&G~w}tvBEJLCrl>4@m z1T2KEa0G*7SZZ$Mqp=pgc^1W9v~eClcuw#xly~G4kAhMS?|v@G_?3#&Gs<_-Bzsta66hSa5lJmUd;J47@m+{eYUvO`In_R61$ag1CeQk z!yzYyk#_3ht*2+}Z({nY*-qD@PHa^rf^K3u=^ zbP>tLd8bwBz8zP<{G|ap;+~XENOvO0RfE=(>N@M}FQp6Rx|-vCx2kkZr*OArX5 z9R;KIdQoa8q^W1kmVK*h|K5Ms9`r z?3}Ew!;W-Q^BrvQU`;JYc=NIgha?dQxx2M~JF?CRrjFZB_9eq~e4f_O8$ptu4>kpI z`p}cbOUCi$gUJY_`U!zlSK2;hspNc|J!m-N$WjYUztcc-y24K6j8Nx%{spzV$!Crq z`c}k1gTYd!>==}N{b24TlzgY*<=0aQ$9*aRJDU5%&JJ{&2t!@;ggsL6h8i;1roH9R z6_5RvYw#ru6zuho;(b{$RRJfOT~;6TOPSGw40NB5?teRvgj3+<4kIT%y|<%368?B1 zL_N?7974c~_4klj^a7G$!1|mk;5N}-4jgGjZ4G>@Tsk{<$ zi*7DE56yF}S7G88O>Z(PXKGvqB-#3CJse5jkoEJ>UQ8_xuLTE2Jn#OcNFG66`!&c{ z66MyU;@DZ~wHSk|4+8zwJ+Ksz3L`ax^xOXx>^5yl*9#AFXJdH1#4*Y#8@-OG`!807o7K9+w?Fn^wUb__DA(NSUV z)rGdl!eQYYPBHn+#7qi-!<>UDJ^+I>G(~XrUwEqImeatlRi@%EwBy+NguUcrwobKA zwU!P*8ga}-s#g@jiv7re?neZ`bQosTL*tH{{10)kEl;BM8*c@3G4lev3j|Jxwban{@);_`v} zWlKvYKtQ$lr>Sl#~@vP|7UU0RHvZ5EVLA!`-{HF(BzF&V|pl+)^ zmZsdrjK$u`YcL7hyULuERFic~t~IWpj*v8%4{A<UH68M>(QtQ!Jl-zI(V?XViK2^|Xt2kge!d?G&-%sx`XC;~IT!&_G zo1yNU=@J}TXiQgkj>{qH)9awnV(1~hU-SnYt@y5Li)kf<6w@ z;<}HJXF5x7c=;HnVUQxBlGCTt)H|4GMu;=CJ2_HuP0Tx@CFh_&hE^;M^TGNTneHPD zLofe$e!^Lg%HEy7t6t_}S|wlZyudsnU6OvMr_`B7D2?7OGvRj_udeF0ajUv~%yp{g za2wR7iKOyS3bV3Z+0{|`&qzuY2jT|T`gK-3Sp~h91r1s=`Z6|_$whqiniP|&?crI; z>h)#-4XKU_E#;7msN^c5QcYaEkWcNHeaJi{wINQ$UU`*Heawo@f0Pevw9zhsdcTM`4{M-1Ry% zVD})5c$9!kL7v5+7I5j5aaITcfAQIQ8Za4@@xy6(NZ}#7W~wsb>U3cX=@oI=imLR& zGr{z9)wo^K9b{;wGp)yFzczh~q%TeO08Edmn}0?n7o*^xV)Xhy#puVL|G-yK!kn8y zYk@y>tNHfUJ@viwo<;m@8O&!vp57Gr78uAzpkc-`OXq@tBPw;aqc&)IeGzzjm1Iq) z^IMSGV=dBh3tvD1*cArIXL#)j6v`aOoU*oC$1xE%3O5&sDWL1>^q8-F3o@#mhj5Su zhLG*VJ<`u*2-qp}i%(418cJzL#%LMfKc9MjA41rM3^3}aB!+Ks-a6QyO_3)CctdV` zsv>iO6Kp!>+z+4-+_FP5i1a{)6}viSm95`Q?gZ%&W03X{2+g%9IIBqI(+X@SzFuS@ zn2OQl(79wA3)#d&C({3OmWqaz*+mk zabO8TK_u1)Z8{xTi*#R0xw&paw-F2h=b?OKTB-pGA}G{yRj#=b7%N^+Jyqi>yTeVZTtr}i)=zjEWT&38T9uqEh|NCwC_3TUt2p?FC)|B`WV)iuW7ITt3pdCOQuS@EM__V)Ne)v=%;5- z1Ao_au14=}4PZEOKflDxds(JVqZ`>?AqhAMO3E0<_T%6d;a5=BIu2Y!u8VCkU^a^Y z#yx5OQRYG6pO4#+!q1ZycQvJRMi|zN3ASh z5cx3+xCfEYZKvssLQ$eT!#BeX=Il~P+*7e|>Q1WLw>vz=KQ~nJ1u-&rL4HwsGrT+B zB-}-#6w#@IF8SuT^!*>ZoXt`-ZC`y1b+_%RnJl)sIpg@i-Y|7-tAyxPv|elXhIDE* z1GEB`;7#4xFoUmh$}uOnmD7ZX(PTXxbZ}N9VIqEm&UanzxMb4D*R ze;Yr^UAF>EwnP}4rW%@F@9Qv#QC=a5Rn{=FG=c+4RGD$gxEf^b_0w6Eb6$^ObS3U* zM;FE6m}FLL7^1+s@{|xVr7%!G+7cTlzQM4RL$_$ac&=sKk z1Q6GN#D60YuYdIyoPUv+XWp-d^k?f3g*di-Nxz%oitSh%mCT!K&D@Uz(M< zMO@v?!@eoaR8km6(*LZLeWN=sQ>VS@JJOw}-O7TU%jT~x#cB^)1TCE9qrWTuc+E4~ z{WX724S_R|oXxp2qfiT;E+LW0i(voq!Cpgv3WN#+7vH+gJ+$~evB)M~teTK7yd2?Vfn+lcxV1UV<7zsYn!H&56fumJqT$qt+U zmq$_$r4$1kgH&z2fK#j*+Je|uLlb!QQ3#p~$01^P(i2iE5^zCa6ghMJv~2SL^D21QgZ_=cx9@qK}}(K}JWA zEcwQaiV{~sW>Iix=olyG=F9f=^&aYMS;S>5$|Z5~%4cPxuD8;=Ad-~{-F^_f z+mj_Bnbd)-$gy80xrk9I)W77%+8K%SKa^0$`o_D;vw>ARn0GG3AFV;Afy_vS@<M#b5W&G?VEmbKJ1w+pm)Dp|IBiu#FINIQf zAq%@V1_(K+=8xonGxw$LQZq8va<}}x09YXljspwY*fON{i|*7ju}U$emw62?x4j9K z(o4$yrY@_OQG0y!frh?F-7Z|mqDQ9~>Hm#GR^4()j+c_}U|ssWoTvlQ=@R{7e8 zXpimp1OT-O0^}Eib)bI)Fz2&gaofE6DxYeJZHkM6iq%AX^~M+WlBC$o2U{My^SkSc zWi?V&i_?J@<>v)9Dbo^v%w|;9sUmtpEfw2m_b=@_0P+rfcw>x$D>nmr{zbc4qTyEh z?5m+l>a@u)DxK!;mu@_=XwTb} zl98>}7wgtA@`MI_AG&^0C5ofbM{m~4iax&ZliHAcqSrz_&-82cZq<3Buz3>tx5=TU z15Dyl3s}R8MYan}OE11hBc5O+cH=xMT893zbnb5}r>1J_-#K+tDl^>=Db6&${DQF0 zD?O=$;w@0N+p_C=(q+r?szv0dS!dFJH#=JtxVBQ*&8I2$UR|N!!6p?N4TSiIh z!2unxv0>oQY5?+g#QA-Il*cGN8!y<#69LYBG0hse|8qyMf`efmlgGmq--oL`<%zDC z7dZ%jfdhc(k#QXgx9EWD?>h)tJu#{UFS35PJJhvXc4JoWUgXBO%<=L<%<&jBvMUAZ z$!til8UMm(BDBF>M}{QGU8o1eAq(XV+dyDIYck6Cho94qLTDe(w6sWR?MsSrdXfgF z#X)f1<6av-ckupFpM=-u;|`F>Aunzc?={-D?@+TblQGIhU0NOVsMG+BDcikOSo@8z zPoAA}ggrVTzwme)NWI3MQkb9U2w(Ozr`rLTND?NaQRtIwTy9QHH9#WwZE%dR0;+6% z^k)7DbV6tIH-RS)2dWMUgwf=r$XDGJ0I@ehcZqYWVg@nBxLU{B^cWoW-?jprut4kk zmNIIe!3_PwU@~LAxoSiMb~84@_~N}Y z?@!Pj#x)TH4b-^S)|MFmo~LIBAda<(VCBDG=#x`4$|gZ+7lKd~W7y+B2-=jGmoFT7Jhkb~dPflu7xH=`9gW!>kLnfJKcnG?%FhpZO{}OUJG(Ufj8Ju!| zc}yBe%`pI%?LBMTu$64Atxsyv*rrF6EX&y(eCLu9e>A_7*Rv*x{=SsdL^z+vtc?0c z?kOL`qr7pI=@T-GBHd2|ut{ftv{ta6Y^&C^2S)c#o95GB>d@HlFi7cKi8#!yMnyo< zkAw%Daiio!J%@LfG5y8dYI1-EyJm?3qQe1kV>u)XGQ%*+ia_s6v%Ng&0tXJ9BUm2+ zpcs~T`sKvH-bxNk#X^@!D6FWy9IS9F%>#|k4&e1AEN!Z)s%WUsh=?9hD5EWZ+$>bP z_bU19a$QcxQ_1ymmsKBZ@}#>lY={;YAqma7*7XlUmNYQNcZ-L-aA~T= zl(hj}9Am@G*MI-^%EOzNi0oo(&JYi}Z)PA?(YT{BZXs#`k4Zedb>G%rv0PBWVuM7G zDS?~jOneU5q<_dyC76iji#5?tX7PO#9|WE8g*f)7o2g5i-Z#YrG272KLU1ev_ zQpt;A5rXgRNBXyhfV>TPP}uir+etzO&hoRnY25?#@`cN#^aXDjluiC_*>s?i3o10P z*SGyyVA_0ACQCn5(^D^csKD%+lW(40lS0;qa2tI_!;0}dqmlLzi5sWnc0cRU>w;tu z1pox%3M?Uh2^f2j;f%cU8Ky1hCma39xEffik_9>hXh2-f6d<^fd>&l4Y~Pl^;yk`0 z)h z#z|eLb0+S_4Y0~PG<-p6D)3G17bsx{_%Ig!jkO@}h@}P6$IjVR*4M=ayfBG-&t(*q`mL3!}u%wD}okwEsc@$k-5tjSYwkHJra=5jzJByU)&V=~DaI1YoPI!4G>Q0yZlo ze`gdi>)zJi0g3xocB_ySm2TDXH{q8W6_}Lq*v#9@vy^6sN(Aa( zQatvhvX-<|4Q=Ytm%-DAp2_w82%;h2sr#8{p3wi|Gr+f1XKpW5sk&sC^r^5eCH(dkrGJ!PXe;ueKwpln@C&IdiL z!LKomRl=)-Hf5h{)PB)YK27%U+fCbAk3aPLKdF-QG|5AWV;hSA|Rl9g7Ejb(EHcnc#lKA5xU-Z_8X|I;H_oQSxAAwbY= zFT9948VL=8knER0!dGRW+F2TF&)dk{Scqjm^Z|OHqDA{vS?{|hr~{8+Je|R#iPS6x zGeKLf!QJ)cEDu)1#SeIFs45QxSh2t<7muSruYlybk#^5O9nIPJpQJC%mFXX!o=ID| z-0}zUd5++_Z1V53Id)PXKbo zu}K(`^%rI0_hp13q!3Y51Lu|1(>t}I*pVO)-0bxi+7O&|j{%7@-bOrU>(ECg!bQ_h zy$0!o;Wd+HweE5khsnXp%J2!ZBqYkfd3lOhQHyfmU+mXdq)FzP(>f5vkkpoMqMXDh ze+k(5KgKxNL^OwfBQZjN(caba*tRp8)ys|&@&JGFK^G=0PISovW@gU#B#P+0Z{l6l z2`oHLMB^`g$|(>q20pZ5E6qA=Vq<_-C56)^drcIhFCjiWDDtFer{DTcNS=~ntl3pb z6>z9HLBASx2eGP3%G^rM_=d%eEXFh2^GK!dx zQCHurAo*edXtTXrgErImMvL>dnU>Ugh1;fc{TN{sH&c_&f$rP$GvVLM&l!1XzXvOE zJ{BILj>rCikPEZDh4?{j#IDE@NVuckOe2o(N$@5aW@x-bfc^60MlE_qQHG)D-*^Rwt{|!>`q9bV@^1J|Gkh{78 z^1W?*$oEw9?}3gaojC~f?CD(?Hu{l6$AFB9EZZh#pk&ic|+_+;|SlqdVKEthXnD{j}RY?XT5l2Z~)n1bFwGIHM?MJNWL z6idovRyfQA zAHN#zo4$4+6b83rS_JedeVcH6lkhduZh{+ykWxd&Z5)CtiJW^B>)fwi8@`Stp()vK zFe1)f5QTW22b-NW=uy@>4)->697kR2=q{X>*-Mb3PQiJ-f?}C}TfR_35x1KM(wLg^ z76=4F60`idlh+^{pkzcX-?$!P>S?Dg{rve;3eeO9n0`6!rW6o3E3I+{-Ke3y8<1&= zoF;(!U;S(&q}u3xetrgKs97F|tRDBCJR?r0MHE8D{fRL_a@0W1_9qq=#No0`3X;7l z=I3w^ZBr<^NB9wb@m_cKDp>-BCyNEqtXoRk<>s4?K;E7;nqzlgz-?80cB~P0Kv87| z<16yYhZPa=qK%I-_Cyg4h#&{!$st+9Nv#{UWr*4d4p1Kjp`-Ot@za*Z5AK^mq)phc zL3qpNMEp>A&dftpGxppuqdV0!7Wo3Liwu_H4=pO&s;X(yq%uUDri+fv6)_sFr>AJ) zmWr(Ec?!29-H6jeyTcf1YknRNv-8q&g{cm)8s$Nb))=N&qtUmT z+rtthr>`V>`H4Bux_x;Pb<4{y2Or(RE1l;$Bj6wNWbD0JP{#$8%4`_SVh!!sSZ4^o zTj+t433LyjZtz}Y@3i8oBSNw4{t(wSCOokoghi%HB@7WG%P!xUs9XYG-TF9n9KuD; zrrXrSsq38{OjGL6cWNv;gF59u_z`{$iTZe&q98T5mmmFO+JA|#5@mMB^udVh)%#F2 z-8e|mm*{egWd8>V=4JZwa;ZB&JyAONE-3Xw8)v7!k%1=Z1p!H^j@ULukqU;As*dZq zstqba=F+NmevvfuG*{tXQm_z1ETdiTJFGG+9Rf)*jLY*9YV$!aw#>n<(vZ}|^76`6 zk-^hvlztVI;i%NWg>C0yH?k~teklg=te!yxBdA=~9~F#9gq>>nC4;IbP4bR@BQ4IBgN??5p+8INTlz^z3U8}pCM5od zF*1{iv><=r*XUsK-ZQcC@@;B1oAPto*$H3Xtk<4s^cZNg{46!&KsF`2lsM_av1=i< zG?b9CllRu9x=&l0JF14NM>>)A6M5)WkdsKiB>>(pLQbvlgn&52taWbtbr#Fay1>Ve z3p1<=^ccJ=n;O#6N@$J8dTr5!c_M_$Q)atN&$!Y~zt!Q)iVkz7i_H5k#f$VWul-ba z{dvSFVB}L!+TIRzh;2dsm0A~UUrpVyNh5pf@PGc|b9_+SEc~cb+^=x^HRR7v1m-~0 ztp3c`Z&rJS_J983=Vy?ky=FPe1c(#%nKjU3C3=M~a5grcSp9AgL;U^s)|P9EW@&*5 zkKI%|O|R7+&W2nZ7j)_RHP^pMpFlAkMN1n@)qR6{UL}fHdln%IsSc8=*zg6{9<-Z) zp3B41XE7qgABFlwTpM13Q0C{jp3Ip= zA2khNvli2``1k~8P6LRf6IV!XzaZLwDAc*grF1%eQ$K+sJgUU0-)8s??ZW}rT^D!N)(1byn9(2SH$H9@h|BwPB0PEwj>6_%n;`BVi zOvd#vWRez(n!o`Ap{O+AaD=&%9iM71DF_feP5Qv}>pRa0T8={|i=!uy(8ix2!VAoV zui*W3T^l>Zb~oXy1(YBF7aHJ=D{k!NOTjLKY0B*3bPO|ehhZ+xNSHRn@r9zF3HJ~V zSxx3qs@=dodLo3|69DfLiMZuuIEgD5QQ@S3`I(}F@XIJgz&a*#SLcC;*$R({qRl}5cB+k89x*l_&PY#kr^nSSL_vQH<6w<>iS*` z7xG1OhQM7s#pCaPDol>Vlau4KX+l+Pcq`_=u$ z;u%5@ns)T9XFuRIKLfWwa`dS&@?|H|4+%|iLaGNDpM5&f?`bKQ(>;f-A9dmAAjCa1 zmE|PW5&$5|&lc?;FBJoi&YJ!mGpWaduXY_oqfaj>CEpW)QHyW&oLw=w7w7(abOi8d z_+B0{B@_3d>Ob$_{>{VJKm)r(%5~%JLt@CoJB=qD4Fu7@z{dPfO5>rw z0ha(@`*8BZeI3yMTuShB*hc8cPs!}D1OL53NLL_)Aou_KJwPJ3;gu`@A%G5jN($nT zW%4J#b_mEibc@2&FOZ=1|La5Wd@SEOA%o_He9~(c11vCcHzk-O`vh{a!Vv^833`Lv zN?IvXLzSUxJdQw;0rsT&tK{8(V%EU?)3qi4%=2p|Lu_a zg8U=FC2fq0LT!t&w?oq7&=U}{B+!$}*SgrVtH7dU9yV_~rUyC?#MyHr)U0r&e*!<^ zryqDp5L?Rd;~L!x^*Lhf-UV9eovJfPow}*M#(eI#;KrpQNwEoZQc~CSj`3tKZC+mW$0ohdPNInO+-DqG>eFa%A zjc|fBJ*12+6d1jzQ2y~RIS9dkQ|4GP&A+1)Pf-bKtenjq|FE_&k_h0Z{5N&yv>Vno zHhR#6SLckM--lfu2?qmtCDZ2lVLwkPI1Kr0M(E7p{G*pmt1Gd*%`Wyk;H3^jk_tc# z&Q%PoSt68<0r$2a^&;;NQiDo|P+)B+pM7AD_^uO_v$W)nTOW!H`quT|yr8;dPKAAS z$^Kl zg?7KkuITEx>UNm*IYxeF=kV{^`I++ePtdwbZBWBG45JZ=;5GX0sQe8`IlFyo6=QcC z7xkz*2ED-r`Tc{e+e%+aKuVovev6QRrxxt{clvbTC&e+*;7z)PV*dAJy9Qn?&WF>jEzEVH~gC zY%B5j5g@cP5uCmjfLE*jY}vWAyq4-ML`phDcpa@qoS7xi4iU@8eyD@guLgkqQQ)aW zf@xPTEkS~-209#Wo-`N{WTb5ZV{r^>a0b~!y_aEbR^$iU^w==KK`u1GUOD6W`AQRrlXQ%^3;nt;v&glzry4moD zU7%szyO>L)7@jMrG5}>YuUJ7Ojx~6h8jGw(P>`7fl;op97wIMIx*VJjm;}f8 z;V_=Q@k(tG{L-mU;sdjurK+OQZu(Tko< zav|J-W1p==6cLpkF9d2L7bMnJl^%+mMY91XIUuw%q?7R(7Cy)b)gQ0o*Gp?ifk+t!WogRiRf1eP`D4xlK3 z;IUJM1IpF^Bncy_yl~KrsOcMQ>H505K*Z(02zHS=*e&#c-(7^=P@F|uB?6$QV8B?V z0BDSQEJrKzj4|3f<#O8vpOSzPh_xt_oK8#te~<|Gh|01zLi-`ZZ`y>meANUOxL(*^w#h03_(5*;;{7Pc9Kg z0ZBOAVkpdH`cw^krg*yAs(<@c(A=Hxf_ze1K8Lvw*fgc#QQqqM9UQ@-`;>V+90%u+ z%uqh$h&`P>$_LLa^7P+gbqS)~dC4H>RFvzrBo|II&r3^8h<3+BqTLak2>@u?n<<}E zAsNxCi#xCxQr2#c+0PHK0#0o=4wN?D*t7^BMt$puqz@=VJk4*3!zKLcl)=c+*_UE4 z55A)Y_>KZr1no!VNDg?fEeHs8zy0+6-Pcz{!0f|zSrwho%P0T1;z<$!)KV~fY1qS& zK%>ay+^sW~$x`XP!DI4yEfeY&pJTHE=03JTlm*vU`ga`f8(RWoL2SAZ1>&aQ+7?G#;wPureGNh2}Vr~RrR=zx>8{!`8w>$vO}m2x30`8MZOOj=VyLwD zqc`bgU8dqUl-Fzs&8-i=tP35Zi)XA%J&lD~VWRo{c!7r|{aTMWq;DwvoY*mJ9bhyY z_cs(nhYuF8ZaU=|gcT^Ko;5T3-Z0On!rAMT#~7yUJ(dOyhV6W<@VpB*7nfS-W$owi zb>be+%%v5$FJuF`%fyaCMaN~sj%zaK^UpBN&n8{Cu4;BUZLV_zlZU^>tJ7ZHRK={M z`YpN#sffL-v(R8d5PE0H%ay=-|NI_L=ssWPcE9Tx4=Vwea^{sKr;P{$AD-W*_~bXK z!%<79^A*!KzgV`Cxj;m$)tRUcixe&r3pk6B7oN7*7w#cnHcf^@$FxPkJbxs_=N-#iIk*H_JDdV?Iul8PED6`^>L^H9-! zP;i@q#cB813lp>}uPQ3PHH-f>DAX&bJb7wf-}k3emt2azd!eJu;gZ;DT6YCKAMR}v^vD<*Pw`Qj>`S-axoI_nHnyAWOX61tDzu0m_ z3H=)R1z;R$`Kf?pEsr9iZ}!Tqwhkn2-Dx(zD*=d1u?9DgUm)`PI>cIz(3x;Xi!^l3 z$QcYlaJEQi?=$Ajr}7|YSjIS(IJS`mc^1V}LJvP;A#n-{i!{VWlw@RNoNgOS8jUt^ zZrk2H2xIZdWAOzu?){PHurM`)r*==F1v^5{L1ns)oO}eCnTs3KS&d`{Q5TVF1qft4 zTqX@-nh4^2!|#M%n(J4Gu*)%w)`{RbMCm(ek7x`)vLf1;tjl3?0}%2A$V^hr?X!Zq zVm)e9^G@Wk#)CHqxAwfprThUsc%N*~-6Yj50lV7!L z%};|Svkpchw$ONPWP&xpI75DVyAyU60H>Lp`aqDkUN!qFA{ql!pEh zr7D*F=*;&Io9#M{suz?xx^+IK^0y|2-dw6T#uw8bk$d6TnW3gKbI~_FAC7M5#v)P8 z=#6CAHh;;c85^g#S+!`?TYwF?df002K-I_co5$YFb&gjrdboXZyesruWH7A=|o+QF%r)U(TZLVn9wu+fw)8_ojtXYzGKUV<@e*@a@OJLMWfI+d(4Tv(9 zE`~|d+jmOZ8TpY^8NGGAk6AW&`*W*&=n?yb=N7cknM~amF8x=Ts(por1jy{{bwDAW zE9U18l&)AfK5}H0tiW>n6Z-r?^^ygUf?nUIA+saT6W`Ul0}IRizIlIF>s`nK4Fm_# zhhr}PA7S4ePj&zPUs5U!R76BZ%FIYMWm9CwF^XeFl)X~gA=y$^R%T@$dsJj)Wn|A| z?@0FVdbw}o^IgAx?z@NN9PjgfU$5(WUe8f~pZ2Ri?Qq%AhsWGR#&R&^VHf+$Q!&pb z_cou|s@m>-S9S@hbOZ_h@+~!z{CQ-vOa+Mdgm_VXGXh$n<@y)U2v?#7Ipi*<;P!Sa zqNNVm+esAq8)WS;{!1D1F8X{vp2)wy*OpdJSHB2W@C#E6R-BS{aieLs!!i@RrNr)Z zIy)ZS`|2#s`Nc)#Ic^&wy!mQ$K7RjPaaHT#pQitORXp`93IDVLnB;Vxe$snN4^#zT z1crbF;qTLfOs6PTwN12PwN@+jwMI!fY!Ru&UkIuqYxMzj>0otu=k2S@Ft0%rVg#<$ zHiX+K#-cT?tFS50Mu$^&pS%7aazfPs?w}tx8M?W0b;gW`>YIZ2V;L+ilLT6K@xXz< z#Qu&psR5mi4fdWEY_>QrhonhDxBAmwqNIPBoQ-9_dF`^tlD7L2y1jH8Z!r=X$ffqziRzt~x*=Ccc@?DYN7*mGnt=82{Mr11n^qg-?at^&i=A4*I#9 za3?-y|BAUn9nNb74OVaKfqSC%<2Q!rEMO9={^ON`gnzP}oj52hR<=gaeIE0hF=1IM zl6=lVEinCi)w;JuV3RJfiy%Q3KGScwqxW67T&e#qy5&rdg>K<$?ve~7zbdjFrn-FY z1!OAU-am4vP*Oo~Pb$c(8j`hmhHf_Q|5+acLtV{@+FYGn5CU3~_k+1JgP;cTBEV9m ze=y*UKfkpL=YmuJ{39BzY9fIj|KbZVbneB9Wmtn)7s(^Q5>bR8k)fNg3pH*W9Siq= zYwM)c3{iJN4ua53C$~;I!ZvvFC#Qe(F3PP_VHd%HYrBNF-{%qGHxEBqIs*4-1@lvg z#jTZRezek7zl)C+hBf~-~2uyQ$IQP03J^%3k& zSfk}f-29M-2z#UdGK&Sn{O~EcV2q4pcA?&d(<;qM&^*X|cC(2Hk1vJ4UcmYXSaG2S#=l{WL!=Q z&7Rxg{ccGB564-DbblEaV{UGLE+^?R(rf*TtjUQt6S`;3!_jH7;$h+Y>gi-PRtDG>bab~vuOW5D(Eiq&|})j&JSHN(%;*qXb z?mLJwc^4Ls&nwBhY&Q0W=QFGF*}uz8joli9k_%C*^*}bbYB6>|(!zYrbQKQ*VaOR49DPB$~!-Nv|Qx z1=F}!ab|%0#>fCEo3C(U)cLK;yxDdA+kg7JAT;3fzXyHXhSMK{gy7>M-IH(H6AZ~U zf}{eG|I)+JKM)UvUb^gnaMKpH=zfyy#Y_KtVE%7ivX`P9Tm$p9KE>*MKE8TL>3On( z_*bj)ZF`7^;wq^99G#SK+N%5{^gch2oy=96eq01<0N9(Iro3bsit+;lsE9eP$!&e( zE_XHPGeQzYls3;vl5>`X25HRwaQ6YW#bN#X$0g>_$1||x42I2_KGJNT*DuURE<-i^ z{fCfK9(tNcmuG z@D)`vZ(l%hLby@sfBBN`+o!czt{7?D-|$mbIo3O&9{!2N3e2o5=qLtU>pH?$fH#$C z^W+HZBm|D4#9<3(r7%HtK#duYh7We2R}e#FFGm|II0XB^%x%It76Be)vd(kEXTbEw zA1OT`_MQrv!lP`UKrX|wzomb7tXTImJfwL^L&?X;zeTi_I}qcv%y>F&)UARkm|v#D zWvalWK)o6d#h)Et8CZ3%8$4plLAe|YUzvTdzVKjj3A))Bn99=Z#xz2}3Ky8U4frt; zfXA70P3l-b<@^0;jZZ%f@VSK0OtD-3(_@!3u(ry#=qOLiPm*sxW;pVSQkCV_-_(_?`c_SE z$&+^i2O%SKf_!oRl)$m?rcc>x6JtBW!=ERrz0zoZ6R6al^VObnYPmFkZ9G3@DGr-} zVdcyJa74DEI1zUnfB%#7c^)2~X%P}H_TA&<^q(0DXlwT{SGiO!zwZiw-nTSp>a2fX z_-FKZd4Hf8**oTiVWC`EJ$)>UJu7nFrJltq^?eF#Ve(?#6!Ho#+r+Ej-y8~~yCm1X zra+P!&E4!UUN0o|0PCcY!8f+9%hr9gYxHO2`PFJ3seA`Od?pDGhd+tFq;3)>*scep z#ORO%*2gPK;uGzlIGV-RA6n%FlZIAE)3KpnIwD;Vw~!4&Mcq`uBEN7>irUf2UmMMs zqw#Y7s3;Y-H^D_6kwDXBH*YssY2AW`DV5>tGr|v>UxaK2ad-Vm7~iOWG=;e3V*i5m z$~~UH-U5SfsErzgUr#~gPzZzAX-cbt`Jg0?1zpCpn&22|caO`Cmk*rEG~L`l zieE;Vwm zCap8@44pq~tL|JYDq@b-_ud}ceX2JuxWCg%l+b1^$T_xu04oe83D!l1kMxi+&Cjbe zuW|>pOfG~&vjL;7EdWwd?iqg!2Zp_>O;s}R(K(llHUTJ*Aj7%*&HHNHjK{~#Pw{xD z_*FQa^K<_OqlgZ?2kS>4fDkm1%m@WL%AS7^7bPaaW!qcsE*;T45d2%rhIMLa8SmZk z?C6eS$;b%B@=4gU%B>$viLL}pMxVZ7?X)BH)lI}ogoqM%u8pO^>TMK1N|U*9AKKv$zLZ1-xg;rRAtANr~B0I~GfB7K*IFT*GtH_vF-E|7$V96Y@c9Bcd0tWW~W^DYrKH z!QeKu`kY>2Y(5`rT>YlkvNyQGSgd$XMyweObp!nzp9JiD+3yNdgfz^fu( zZ!8iCkOl$}!Jg~xpjULxoOckR{ca0+-8rjAo9e-4t^JBjhC%7UnFWJt?bedVM95EL z!$Dr)FpS%S93bnJg45v)AOpu=iBEh$ny0$=5}gTIm(1bb3>MljkcTt#y#lI%z{?eU zK5+O_2TA3QjLjHsC_~W&uW6hDwlg*v#oT#0-&DM(`*EGQ1Ledng7h#nyJjECuCx(O z{HaKT4CH}_EEUXP4hHhqUi9Igkex968zR{@MyCc;5KRSQ@ zK5KeEoS`#7o^WZk_ikWU5J8+;oR^#qTP_BxZVuq+Y?l!0$BNlx?Y4=FOVVip$3fa zF@CKsRSlsA6&&BM+0Nr9f017PDpb>XCfptyaG4_|KCOaQ>sAo^r*ledS&?`lt#(s@ z8tR&JQ~EnqiefuY|B(n}p}_C2($H-Cqn{QgbKl_pcWK>Q->#O+||9_F>iZr5FcFfM^HU@fx{eJLb)Ei}nLG>`{eAF1^~4&62i z_PJU(6svzkA_IA3Kmt{iy#D@iyVu{4gBjW?hQ8?aFm(zGK`}N%Y{<75^*MnH!5Yu2 zQ0D<-5QDz!%+sO%bG^nEiI3C4P_uyf<=!{O6Latt%*f;Q&&Fj#y;7gQ>!wM`I;mK;vY(=6-m$G1C=ao>cAT{$7W|ZZ zMCF4_q>{c>oFsAW^Q3IKXl!+J@01uDr1)B|r|I`rON!Iptg8AsOYm8-B+K>IIbP5H zFH1%yfjHaT*iP+F>L5W*{gV@GNo*7N%A|AVV|SaqLeZ%pwsTp$2QtQ8_9U}=CtCnm*bcJFe39AyZXL^K=ERhuobi)VdqcRwsw^`C&|VhI`!#9Lbw|6zjiWSwgFqGU}^ z;gsb)7p=NYcnw=52Mfr-Kj=6ZS9LiPH}uuMIl+sG2M4 zn{T>Lk&Y=#QS(DLXflSVKR2B@84fN!EXYTJVmv<7cC7h_097}np&-HLMoPO<6lcoYZVZ%VPRZ>2ltaK6hsMJz1C|K zjm>LXCyv5#P&?a_Q3%a!YpUV%EtRWjwu){Tioa;D%KUTy-=4#?l0}VO892Hfzi&l0} zupP-LaEVg6IQSIix496$Z%g56UZBj+#f^%q8XLzg4;&?&rSECClCzTG(DKvd`CVw7 zx))mFz$WjE{GsHj?ED1sN%f>jus0JD{AeHf$HS=lMjJG&g+JB142ZP>*G2Kc8J#C} zmmUVF9VTF|>NhAoePe{-(Y;q3chtM+txj5v42AZWu|ls(bA3EmVW?{^R1+~ykSs(| zvb!Wc)>fk*r5X{2JNCV_c!boKLlVyp9S}HqQQ}ntV5hdiUiF(^0Gl}Ke}E+<=k%a@ zdul#t*s|7%VkS{#?uX4T7S8!uw2MdP@(t1^8|L!A-Fo;>Rj(vN@P|MNWw4l zJ*vh4#j?M7H3IA;Jhd&#KF?Z-;zthN17L4;w??Uxb_Oh}0x_03Qnpq!t55&JR^3YW zaHprTzvcA5XzRV;Xm;9>@PTdM7xl_o7xH)lzwf`Z2_4c2?1ybuG{>vv*l)KPk=GeY zQX4Q`MUm`_;1wRmQ1qE+5~vmy>b#JWU`s3YiY5I+Hx%`EKzLo3O+#B_b0vcL(%w>JJD$$v`GADb~9w7myOvMvAr$Wy?+44 z3IKaUY;=4!`PAaGq%Y)pf)RWZ<(=cnkt>T#7$|bIlvl>|*U7?m$Jb;b#S0k6dj`~@ zXk#Fc2SDeYJKP+|su=_@Q`<aJ=0{-yr zxhrstJ*^ybud!&fe?={z^7fL|WF<%_H9_+-KphH#xga1i82k4UZVjm7Jr~epCL6wY z^T_ZZg5KL;_WccBDFzvbi8ghV%$ge}{q<)%29)-IVhAVoV_w|P-JP=*m8_81LUkf*Cc}p|SD`mbp3sa^OD}S$P)hA}}N3lCo;l*iO zq#1MXVr$i$##}17yAG#~q-YDw!LE*Lgus&Q6o6MxE#O!o%a4XGSxVz#R-13oLg+jV zMQFBeLn3%sPJ4Uju^UtJ1lgN^`mZMnY~~6idThn%*VhXpFJ%aGXpHY|;=)-&hl962 zzwSD=EU->>@Y9o9`LPi;-lsnO^38Ejst)ziE(K(~gYB|i&7F?ydZ}BYXwwYiyTfi( zYt6^3InYy@-u2Ng8c&|hvAcbFmc3iU`9ad;Gm1|Y%5k}vQD885m~YvqDRf}RQTY&` zR}W}u`)t_9S#nGrTF7K4du;M|4XVHUTY?qRUR$ykoP?D{926_Q*zZl93Ll5#MQ&>GFMevNZC8qr~e77={~ zGn(a1Y1m10WM9O~c=h0bJtY@E|H+wg5>QLDz1jKm2vu83rgppKs**)r(3nFauHjJt zKh#O$98TM6O!y_wo0q_;Z-^fFZMo^B;9;M=zItX4f?Q+#b@~`{UdS8q+3rD<4Gv8gkvO)+glMJL42$D`(12J z8D{k}h?Nw-&tMaHfTk8Lq_|JGgl)0AuTyXQiyoCY>+@`F{@Ba3bj`K*{wIR z%l#yr;$B?WEjNRLcMjX`uO;4xvnAH=fdMeW-)2hv~{m%I=7 zMWy*eN+{+UYKuG~5j>(uN&E-zJx3uTl%M=Sbc>%56$6@6b#89%(63EJH9PTMMYuHd zzkCJ0?-khmofv7|-$(`{r?@Z%&ic1NcKop6grb}R8{ivAa{|=Za{K$>XpDR`JgD;F zNVD-SgRr(LshV(@5U+*{3t*f?D0xkNkee}lBg3UmTmOwh1{MgQ%})3?9rYbFUE3FA zH|a(t+d3eil4VHfv7Em4^5co3<<6-ANdF;>Ry&gIbHkMANjBhVEWu@| zz|$R0u?J^6wcNdIq{hv|6CEj~ydTT#qhcR}*TVg}=h2}85n^3DwH*+|2e_C_aslrnQco#kbAP_|5Nj00#hp+MhXi!kGFLVCaEwie;FH3G&qgGCO zhx#uTTsyVkR_0w@u;;LzPlwgpVNJxthGiv)qI9e)n9kTC1?W~!Okwv=Ph2Si&QR}X zoi6~Rz(EvS?&I@VnWr(~7bLg>*7zFo;<^HRkM0;+tkKXixhCPV`AHspEiNxeiNfS% z=9C^R`L|zc_?}et;bNv$cUSGf`&l!>&ADRHH9BhF-$%^)%>)}P8@YzaMhF^mH57rN ztqhCe=^a0w8tr9v3#xB+B|gVDHY0!qV8IL;|8^0;(-HXJC_IXY+=Z{aLgV=s z4Cp2Aau))-9f2aat4lHdij{H5fL<)dvi_}S4K_HJv1B4tDA)GlR!b0VXB{i98X$x!e|hTMDdAvAfb)= z*!dJb+pahEWS@QCCMDr-3YG+^5F#A42<%FigD9 z?k1ip1RG%7c01}J9p42`NXZynP0_Vs8<L7}PIh%86}?KxavG85oHp&0g)9p^!h76c#K zvyVeeJceb5G;27Ik@an36XKD@=(O_kc(hp>ku1-nx*{9wm|*+n+Ce1<;^bDaI>wx( zIXzvfnFy;om1yb1h#U?^s^4nK?XH8}6zuUQJB0#15v_^(9~GW{m{;`R9+V(%M1%0j zhhb;MV&`xd1$`(b(*=)hS_41Hi;8eB@N2pw*<|-}1~7AGI8HLEQRb5U1XZezGInXV zzGdO!piWW}Qq>T?JbMIiv(03Zj# z4x$3kax8HEK@rlKOi#V_(!2LlpFw+vim#iRD)AXQpFydsqS!m~FUgn~Kb1RZOTnXq zY?UmJN|@iTujJG()BzD33)Qrg$;lu#E$&(j!bEJ1`+!Q=!Bz&qg*@1L$AS${PNm|% z*dUdTv=cG|-uKV?yfXdc3P;?B51x)I+yx<#?*2pLheO%Yv=dNKR#+uE`Zkd~DgV+D z0%lfR@d;YejA6aXzwJ+BEzR>%?pp?Os-ATut_&pRIAv9F@WfRrYM=b>k;TsZ80zv* zvXd<-wTb_9DU-Q)8>WfK_nppG0Atj8x?`56+?d1b>7k`uB0MNTzgEz$x{T2|2U8FlXCD(mKu`u$;|4U?mQ=rsh&$d@8UU0uJvhDWciML zOsq3d+&sC=oiuNB*J`Z7tMsmhcg4MO5i^t->~e?4@y~EQLzKVrt1@8t<&-y2{bd@{ zdzO>*EijqzE;ki*%(LLZh3vee{0oJmKrF z!}e%JsQ3Q{m~-jOW5fh_%I;31<@m(VKt{XECTP8wg&Y>v*LuDkQP3XCx|=>~xtP_Eo01FGkYnB^K+*Ny3rUu#qm~I6Blc&{q?(g*W9-O> z+`cB8usZZR;)E@kxn>BM8`r_{XHsFrlFxo)I6&k#fWcV7oOH7H)>vBY@>8mpJr`WiAw}rAr<_lb#(RTs=){Qa1#&Y>>#r(U0Xa*@5?8u z;(g=fw!R-q&h&wsGE0m*R^AoYP$vkHuCmbR#KMhnh6*~e*VK>(j}kS?kzpe+&3u7U z^Xwx9z63^8P82GmV>^P#t1#c=k3RVO*zlF4hwi{g&8+(cQWW54y5F4I{ z0OMd3Cd8qCxEz*TmB0<4V$nQK1k%zx8|%}^xC#X#f%~T`tYj=H7XWw~42>-GE)@L` zX8#t50Lu};S^9l6{r)KDG=sF)UBzmk15~tQTs{q7t=W()42-wqtg>czzC-G>KJoAl z85xdmy+27Abitt8mIH#rm5a5&p8K&6RMPj)O+hQ3SsgASWUF_a2pqM7kzFS8Uet^1R$YkD}d|+f2SbLrv3L4DsBC@QyC#7Ux zzf}5JWKF^cBp;gGbizRxy`_)R*JW6b?X00=mPjSETQ<9lH3W*KyIq#;%!afECX z2oTR}Jo_Sn3#}4&l~M~Vj8|aHXiaRyw*$!MiA3YAv-1PMfC>Prqza@v!yca;YbnQI zYF;U5{*$x)y#$fNaPZ9U#T}##{;PDP6S5Yv(%S>ai*N*IOK}G2ZaB}U?*)QeKS#hh zWH@Qxv72AzWKZGorRCZc0!klg(M7H1(%QEO@u)W63ImDTM+ePiob8=@uud}jKF+;}%(V2($J4})0!@w6%p8P$Ril}X z%`~2FLfp3fv$s;uAL>r}m=a`i&FUr;GkX_x+H>m77tahw_*v7%o4R%%E6EgPGn=rf zTBlLl1IwT;Ev>#w2aia+O zkXT-96%@MwzWq+Ih*$kl z!?9&FkCi6dVep63pH$n2##+)aHD}3)H5f3%NA#xRWHv4QhlGZQa?c|LSFW$lH7ln+ zC}zRJNGuHo`W0e|D=(4lDKgj30UI3bK)G*J`Z2pK13=9NUR~FaCoc1e<;p~I(#&bAe*vQ z6nla^=TQv=zLN@S%#boo%7l>I=&k7e+$2Q?x#!Gwd|9)9j2ec2j_&`#Inw3v)|OUU zY$VMrtTf&?9W=#5CeK9?L)II@#i=X%cQna8=iDVT7*fx?{&Q4js}4m*B7nsek_tM= z2H*jGK3V|Ij{`}c3d!uI(LN5X&wzM&RF~6Zj~K}hXbI^xhK9@uM^_r4!t{Ckb-B2N zZ_hLzI)ajSO`eKEs!6Z+D))Qhos$4R!Anya0|{>a?GikPUU&ucooRcW>rET1idM7U zb^e8-yWs>0G);DE=6;~djlDp+s#!QjvC+{-iUBy06R=u;2@AD6bE^NiT@v;A*h_w{$5th=z}XE+tARUBWbKp8pib(@+HolmC9b6vb299Pm%k4Jq}c@T{?IJp3InB*#k( zstl^jMYo?5dQ+QDpR~K+P01r|+Z@xJRAR(yZ3hOWUp#j#lIZR}`R95b+NLa#h`94w z@pZo+E=^NFZiBP1CCbryII7e5kQ2`TO7iP?2mv4q>y58^A-9*=2hbipTdImsv7(+U zy-XsD*|GbpYNH!US!mUbZaqzL_YgL*>A_C$An8-_D^dCz4@FddCx2=CBWx93D>e2e zHKUW*_(Thk;tFU%{08d&OJk1db2&s4#U&ZV!^y9m+9q|~$Lt@G+dmUkT6k1wR#{zb z8?o|I34k1RN}<<=LjndooW4Wi&(NMSirJ3pDjicT(C}^eZpCW(v3a9@Ty1 zb;sY8%wNZ!O{urss7jB!^x4Y+XnKi916ao1lFTW|36QAq7EZnF(WVW`fV#JSB+)&# z`|#Ac@$wZS?-F9`xj_^ug8DNmSkeVCjiCso*W}aQwnY+lqJEp(Le%&60B|!WH@#g^ z{N6q6HHnKp3`vkbr(fY#r3Y%y&)6K=$V)|1l;k|fR z%s%Md-@^RXW{fkTzVU=!GccFCswz9D#}pk^qeO9ka-W)jb{F_$KB%+<)FJiUel#km=97 zVGFML6hgHhsXbJ{FL6qU^Q_yCl|(ka7KtjI!RiBDXU7!89DmrGv`lsewcSYz;uOOC zuxI$%HZ?6+LtoW2Pw%{x> z$-+p47s#`i2$;y6=BEouHyO&mQAZu>wU;T7AJ3NMllA5Hv6DsZSLtBOA$L#dS7IJW z7|UKg8Pw0x)?|v&s%r z6!9y-I=fnKdHMX-&%b_?fa}0YRA#4Li!M^qJE7wmnfJ761sPG##qBn?9YJ7ju}5iC z{>=h%U^HWma_^cj81dkP6z^VRziM;fAjXbaJ_s}flK#(W0?$P2M^Uf7+C=cj3RG1 zWZM+qi=2}n*QpY~kAY^4)nPovALKF*2ZKD!8?n#~St`kf>9&HcZb<*cJ$?XkhUnBp zCi8zQDi^?hkp;TQ064kwZt^9jufZJiJqGX!Qgsa(Bp(j%f14Guv(XXS*mZZv5NU2c zVBo+4=kRECG2&@}1ibT5F`VfPM33P;qMW5Sux{o?at{DxAqX6ks0+Z6W)si{Vt{Nn z3yNT5J+WcWSp*#p3b@V-@z%0i0%Ac2_Y{!hs%Y1yi4s67*jzQ44cfo=%LCc6JJ}>3 z;dV@Au%V&W*_?9e5?~3QGY5^{s4>^zi4Od?V@d%aCrl{sL2gJNKek)uU2`RhHV%N8 z9K?@xxPy!!RM+Fvm4j8De?b1Guttz%I8_rNCbq)8t%~18un$%V`y5|v!nC{G3F*-! zox5=BrR$X!Nsm)?a;&AA=ir0M11ms#$JyU&5S9j84h*s^KfviXp+n>URzA?f%N$Nb zpsk++#vTsnnV7b$JFMX0Rbl^nMX}-MPX_XOI>Bp-0!)1uxwj8~M2TR-3PBhfmX>~a zGvyu#lP|_Eda5h}0SGTprSv!%fo3R@O~z@cr*yT|6Q;6?d;U#Hn$Kr7l5g}9e0{?G z

~Yrx0aPr|u2R8c%dZToacs8eapaj{rcH>jh4GHl)5)C)&Z;b2A*Jn(8nQ`CM9~C)g-hchH#Ow(SUPyS0%r zK`ZliAllwrqXxj!v!*bAG%-ns@b;^#tH4ZpLa#h@jyh-ti9^jkgqYv3jfo^dARA&+OKe`{O;`1R~-;n^3Vz)Rt1*fgAvZU27S zB{^CZ6ur~|jV1AdO5wt-Qc-STM|W`iX5+hrJ?YScP8Qm*)s#@b_1%tQ^`L;ZV|*3( z;$cK4oA)E*QTI3F74bj7$udzGvPNoB-F}(l)v(Ei_Z+{zk4OLO3Ac;$41-M47ncI26SN$5eg0SoZ^fX-NhhoK?4I{{EdcO$K&P zr)ds?cuVW<=I@I*Ar)7pb*+<2)X3O%A>PDN&K5qQ7?|=lX#MS0MeJxA1+}_vS`GI& zFsbw6yi4HRtAeG~yJ!J`=A+hE8sn~=LOIwgU@dRQ4F^&nvKs=r-r%F^mERHT0W9i? ze@k})D*vMj2Da0`KIx0)fPB&jM0XAOPQRaAu8&oqKuZYI9*7Ntb)J%~^XSzw#2xu* z1IORw3gEZi!#C%MP3*XjfD(td{0GP-Pb6vV-Jg)4 zb4|_D(|a{9naG)=aXsHe8SmK9A0&U$f%~)d!lSdL>{{)X^_^`)8jE1yW%%j(8D#rq zhZ9Z;vQL_vvCJDK;idczu3d{Fr_=&X@eTG<9vX};;P=4e6=KUt7ZwM)xbdaCk{ zJ$>HQ>PvJ)HqIoPNU!|Re z>3#fKCE&5rwj4`|)YFG(RCW<)7@iRiJaC4r)*hqpdSLg3NU|VK;VNSNrp!i~&pBtf zfV{EjL%;ioic|i+;z0F*>o5JZl%fUy6fQ=w(CYmfWz1Mv7tjedt*z4Tu6<_G*M(|HC9y}tF`M-WVwP-0+l#ZLvraCjqv47=ny(g z@VNl}%cNMw_Qp>XQNZ9t8ti_(J(0n;9)AOXy}Pn9$=0F6$h$k;e+C|iKo|h?Ky_`j z2A2fFXO0|`gU!Dh_ZA9{OW zLiz(skNFz2}txyYFO^4a?+rTlssvu%=_q2ohIr4c(mQWFgXRPN``DyraB!a2q6v z5q6Kt0-Q_Lt$g<;{N+whhxF#)2!N9m)REB@EP%Q{Jx(tc0d(G^bVTPJ1qS90vIB}h zXB`Z#_eZ^w?bliG0?S)Hx4Qf{!WX+@zFlcS5DGh(7^|^c5^tZ+@sNJsjK80<4ht)+M!uO8!!JhY=hpcUF9SL_e-e7XeDZ9|UBM3`PTPFze>rI7efVr4n*x+*R2-YQc)(x_0)*Up~E(TQmqd%Ke0T94Rg%go~PL*A)fQ zUf+sYC9&hfgSnRELQ{f%QK4!A*bM+h|4!>a3Tr0 zL2Vbod1<;AhZDG^d}q_9KA-W7B8uN61G{6H;I4Z!eAu$(W0>kaL}^e91Ls1NBCn0J zM7_sh-C6}>=+8ff>Pst8LCy_;s%gaIQHoX5$K}rEAKKaDyp$-$&3`^~?1NI>Can~D zx9b*RUT}~30*F)&Gqw?E8sB$8LKR}iggo6~S*F&F2}`5UrM zOMi^dal*b#T@*5IkpEsN6dO0bh|PN%D(r!)yH0Vt3lTMKP0ywTaUztSxkf4B|rKhSfH0B(og$VSt6s;hw8tbq}baeu@@8*NS*qi~4Ob_0;5$AU&m%Zf&@ z3lesP|Mond1;UOHQM4d`Q))u)=v8Tj*=Plmnh4qzz8SA!fDy;ivtiSHPqiZnAxm7A zzOVMCa~m1l3;x^jwVj#ou7^Wj0WFr`gEgMOQ}FT>5siW%qnUf@&u@QD+)ls*X5Bp6 zG0gDDNMSs^^ULR=uJ_zBIKh`1FN4^BhRIsiJ4P4K;{6|sv$XlN@thH|p6uwM!QbIJ zj>9@C@rU)B$v~4H|D}DE`IjV z>yN$&uX0j=yfvDAzZ>>%!+t$8FGhB0r~LLiGxi|R5`M_`YmcV0BnfkLAl)p1HI%P2Y5q=4ePluYsM z$P;04sVV2XCbEQ+z|(6sb>by`0Y+F~bx{8^drZEF@8dW?;}isHQ&ZX_edw-t08JDm zFWvoz`9th*`8|iHWLu2D>tbKfQooi%{`aIzH=vk|@HDlzSp{Wb+EV7h9zXX(TC!dB zje*1Wy=g+$BQGZZs9GNk_m)VtjJ-KD%d}dDDxG?Jc;ZC!kWgUz>Dq-;9sP1*^P|T3 zv3e0d`~H&3HVv`EJ7jp3^_0-hiebO)A6xhQyJ?QZ4FBD8t?=3V&!uM@k=-i2Xdo!)Bom1g5=6f;i)MRp@} z1F~qi&70W-fHShHu-#UmD}ivJoTy%)&H)Yj21giS*b@)90J9{2ArlD?SgGAX2qWv~ zKk5HrtZtiM@eTZ(usSoU&|;G8A*jZB?UuAefP?q}SAZrP(u@?4Y-%b=o`bMsa{B_S}-#58rSQ4xfP+JFve8d-ajx6)E`h%>E3kZoMw2 zAsVyhaTvbW%NIN!0z zOG8i^unSyh%)U2lIZ6M=V+|waiQQ2O6er9^|F`Fs3m*K`H}-5lHunOXy*b>P=#Rjp zEv~qs7L!1L8hjk&|K~-xp7jB5DKpqzf#e~_TgXOl zc0i|BUh=qP{#Pbz_IS8RiIKW>x9GbS_dO-}qcjHrLNPPr+n~jRDp)WaoHtWCQf9OC zaB{%+!}g09TQrV01aTQUT}m1ReM=FD#^I9udSA}r})ieNGJqljKKEGwdh;AWvua{WqL-tHlVAcoJO|9-_pC;?i=Hc zCoVk}iB^9H$OljcVWl|h{5rCt`2Ipo*k$u;lmh+jj7I#WDeo=!MK6$Gp-s|xZdCQC z%5BjI!Urk5Poe|1_nF$SjuS8Rq@RLLpn{jETpu+_}C4b8~DJxYf<8iee~QoUMwl+OE+oU?|e1GPtX zoZ82A{`f0Bs!AMah~-;sN_t#n)o~JDB}FU)Gk)DFE9*Rs4VVeAXdL3O%5-P%_Dr-} zoGr1#wM8~hovm7Meno|@$A(Lrl?J#OCgx(o)x)l0<$`EcGW4n(L1Ei`C$PnZ@Y=q~ z5inMabI`dS<}X`wp^4?C_RGW_HS|CJDE=Nj6sD#*J5TV)?z%OBBHRUa_}P?OKJ4TukCfUGP4G^pV(U7Hx@?p=MO|+ApcwjE zS(tM2L;3=BHZ3w6PfBHwZ9R?0dEiDYO_ne_vRPD~Dn{NAqGU!{f`Eds>11C=AzF(7 ziLxPF#TDcnDDYVxmKFsCD(OT-AwbTAJ)bm8Lj9rs2)aa856Da&*v+Iu*t*Ee4Un9r zg|iGxnnJ)k(&BIw%&AcF7+;m7++HH>A|wljBct!n;kTP&QlB@34Z_SAdrpkOD5WfD zDvc|Sa3*##1?%+fN(*9Eo!9yyC{V@-F`w_PD#mO9!yEDprK?odNq2X?#B%H>9GrZ%C|Gc@Xe^(vA-k1 zTq^J$8mwx*<_dpx^Ro;6#^QmGwRYC+NaJ6}%2NbB*;&vl)?$R0bxuW8-F`&yYWJQ? zV$%x7#>Vc6QP)Sw&S}^)y6Z>LmV>)rcnIgM12MP$x&4ZmNS>AcX@Pz?pb}Hq0^480 z$4pi0l!tfAhNcm^9cVV$H|f?xF<#zLPdg7024~H{ukyOdF=M+6w+%siPdEZ-w|NWf zH-|;%Ukg%SviQ{#`l9f$^6#oZR%Vn3e!B9Ql5nr+su&VoKi?Exhj|DRTfbiDc@63k z&i_y-ks)yKpL6V=(oW`q?#)3dpvg2$S|_j-2@i8NzXeOxDbe(``u_=SnK9ZytR<1v zgBz}*-NAIzt5Up|^xTOt%-GsfvGtX7;XhsM zf^wviGU2}YougDuZeMg>HHdWH7=$W4M4b-XQ8fzSEWS(I{Zt6Ghiv&iwVUbY(X z2Rrto^mT3k{*Y}l9@aWu@<+v_6}7ds)vWiPk7zG@4`66KOU#AMXB#y{AjX@bhE6s; zigEMHGfX6=N@aOdh3x{zBJTguLVYV9ERuqe6E5t=HWa2gdW1rDZ_WM>*Y~bj1UrJ2 zOsv_@IBu6e7DI=ABt8*St&e#BX8!GI^EtC?ZUXB(xLNRYuQ~90yCsML5iN+Tja`@P za1>vkKZ%$MCN8S^x0B3C-_yP61pv38EH8n3s29RO2T;BRf(g1nnUMoKCS+AZ+8$S; z0z2{n2sO@g1>ZK*V7B6{BJ#nkhO~-8V-R4ceEZhK@rkDBby202VUQxELO z9q=p~$>#t$4T6ZBD%jlU=rxOtIyN0Qd;L8b&%enohAtRPPC*u`UM4v4Z~GcUh1`(P zo>|(P`n>O>uz|hTgI#o-M`A65q6N=;vWU#pNpRmR-i6b_m|Ote29H5gu0sqmDH_Zc zHL0O)&9V6aF&O_yx5NKO>9$YQ2Dd=fcKD~ozl~NZGnmeUodWaQ63Mr6riErtP7N3O zQVmNj@U6fqJq9-Lr1CU1qYbg1c^};^V5v;Tf^%80q(lahQ?Ls*$6==E~wW@o&Qm%PC_amEm<-4>Z{uqnG;tGxAjB=DiO%KgySu1IVhg z%a>}7Nw^8t<6t(|o0g=OK#!C&h|b^$zHprr?fl>vq?KbGTQvo^dlqK6x(xA+L&B)! zFh)SnV})Vb+C|M9kfMwuJNjODbf7MU7hX}_YzxUAi&uwvp%{46Z^Zy=OIq9l9tr&* zg>>BH`}h&SpkNatEf4O=dgT@IEnuzW2wM0Av^yiLOMx< zs2^#oVW;4?LtBkh@Ul&Q+mcZ_66uW_z(V%!mJvM}sjSdI?2Tcd0#4=LKMIkTzC#m@ zy)rEPyLAWI!MJX4WpOXXK~A%m1FQJpUI)*~mG89Cth9P3{_C-3`z}TGe+h-zp2}J3lqi zWG)XR=HKMCqMD{%4wxQnPtZrPHK%?sd*``3zrQRT&e-vgQ5p&RfsULK)*Bh;vd5Vt z2!qutVioSpCTS7{cV~0K*3a3U<64EmbLXV=t0Qj=ulDur{eCDkneQ{+MLCc|U-a{6 zm#7vrrVR!P3Hf*zk=N{{bIprq6*`@5fY6(RDN90ht%3vB;sUC^eI)4q$=$_%KklG5 zp17lY|K7tLKR*@e&Cxukp>!YfW$RxbKhkZ~`{k)+Sr<@JbO=r6F{56k$eIwGiPh?D z4*!i(2>$^Xl#dbAu&StbEPfM$UcO(pymzb@1>mF5zL9Di86 zwn{IyPAxwD0s>1)j8goX^;Ou5?m8Ma1h5s!n?a-r(WkE2B){g{y>YWa`YI_!XSK5v zCY!bHIKOsef~Djir~C*CVjJR1DJ|T_bE_6|L9MRZpJKb&Yjc|KpnsEvSrTTyX3oq+ z0>$mh_K@>3mEL)ZSk^RqEAwtCpHw;XX8bHuF`RW(=%*_)0%W*+l0BNk%4G?q>k`t}9hg+1IE z)28zz4TulO1`4s9Lm$_-toHbF8h3u34tJUKMdT+SeoW}^*sQ=D^oXW{ZW}=P$M+pm zaQ`smfOCpS>mzy%&`?e-0kf#Cdzz~+)Mu4}Tp?6+` zN%?ur{LP3gxfW1&BD5^B3Oq0 z*eCrWJp?kCL5zj3U;D7sQGXiQ5V8ca76{HhWjZkaofP=u9B1fDzH)a|5Ue>vtj{UE=LFQP` zJ*rmyR#9kzcLZXfn%NvWVD(S!^p*?sN)*3V%5Z9Cou1HEXnm&Y-W^|6}aD z!?|wX|6kH{r$SbgD0`HVQC4M-B70|MXK#{9Mn-r`$S!;DnTo8)&Q{2#WMtLvywn@_ z=llJ9kKaG{(a|m5UgNsX>pY*&$MY9=u^GnTo2Nytn+sJ=P|*2>kn?YzGwM8{IN5SX z6Z1F^kOL36zgi|8dQq^V%>V;^#&bIgJ&=5Eg!kAL8( z{Chn{W>DArQHGR!mpa!quh_W4*P2vt@cwzJs)sc> zkJX^*DP%q~gO6#wPD@9v>;!Z{Xww!7Okx&J~24a~u;f!1v*81uW|0d7X}b>=>6P+r_|Q zO!GUYg2OW;zWDTS^`LbE1){xqQAfz*)6X6tA}QY^cah*VE+Li4w?Hwf<4fKv!X1a= zh29a~OgR2cNJ26)i^O(P&O(jh4W#}QY~cL3o#$^QQG5IXFBu8g_+hVn0k7LzUQ&=A zUFD8sqnuFWxzae3OkZX1HnKUj2FjMY?Aveq*HJXQ+0$l>LTi#haGR=;l;f9ipK@#k z2QOBg0{ooSOVsxIFWrya;r>&2#+pIaHC;lp>WXM(I(*F;+E=-02MG9`A}V%mi+?U6v!8M$9EQ5hKg+%57}KwiY|>j?O>W}{NTMQpp4n$VaEFUiJk5 zUUybwrdiGg5XnFuy}7=%WKdZU}@Z!S-HQ>Ja zO+Z(qtPQ^ClM9rA6d+Gyb*Ua}gXP;MkF?0VkGL}jp!7#xwS$Vw(U+PH zdpJV)VgH{JE5Y^`)W5%wJhdeDg4zTz-! zqJK6M@ajg)+~DSwhwMSSCMQ8{@O^KUx+HJwaTu;s-L*9~e+x!U+BD@XeA8xhm~UQ( z@H9*&ovNxT3Q8f8WUGGw#ARh@oa47v90m4W#xtozu6rLq^Zj{Qj#=duoF*C|a_-Ah z=ao6lZ%-HJhqRHzy96{bTk7i=$q%s_woZj}?>~X=?gFLk&KCX>_7|d=I_clXKf3>2 z63A#Eq9XNLs{^B`HJPBhtG;{+w1}ji4O94!RmeED@Qzco#4WjawrAmLjiL`^q^v9UHSq8@T1MOw2k@Ej%1EtZ)=y-v%nz}Z(&=%cqw6)<+O*74 zNqCMFRg3S&O~@M!s4u{DB^*t!Hz_*h$YhX*lENPmonSQIhWC$>EN8=V(AjH9{N1V` zzH@FTlLCc*)pQQck6jW(XTmzZ2?=NJIJ0i8a3_~rH0S4Ie*ZK0@%JWM@!U(9YCFYt zJYT)1ZtM5cP+L=(t0MUeZ3;-~oWmI5f99DeM!MJdRxees;7IF!e!1U8_E-MNt(r1d zkPRc7cemY|MbGd(zSIX1``O${<*d^hL%rj~7HX!AKL*yd{$b>ezU=4N5>Z}W9$O@I z49ED|>z?4jz^_i0m1O13XFMo0Rw6C$Pp5u$80uMMv7hkVSD0HNflr(B=^C+ShV2vb zm#Ho;H}-cIO_-!8MZQ+wOS1Jg0|#iKLMzX3*zmGQ8l}mfa&yRFvS_x?ygy^F1ItW$ ze|n4}9lgRq?Xl|Vr!yY0EXU08@`m3naDmr6 zX<(J>kbCINf>-yI3a*g)fiZ6oq^{vOiv^=uxj4PH}&xb zH_qUT3EEeIV|f;<>5N8cl}TC5Lq{>zZUg!l%P|Y9Y_;0R+`O)VFF&&ZV`Z^AB)kLz z?1kub1NE9O%z3hRvwQci`5$EjM^?cQ^N8Dcb8Ln{q7hk^^TBB&{yVc7F+XC*UIuHC z-&mEPX;8p2h!lsa8NU{ zBL&_82LT!ExFxpfxcK7`E2x{S0q55!YuPX(HyEmZ6caxm{AM{Yct^pq2PXoBQsk}2 zbYW+xH|-Cd1GfM;J{VC%GxSy(@GXy3q0Rzlu1&tOaY7mNG{!*cjbN0LCoOH-035Pq z`#{N15NCGQ#11GEbTHDha!PZK{8VVt*&y;N;L_#xs%qqJfUqO^fxZV-u+uz^jAVjR z_6OS!Y+nvK(#Ib{hzvM`LOq4I&)&vjRA*&tA+|*lis~`&Nj%~6x)`Bw8u;5ibLPp; zh}AE{J#Y@0NHrtd@Iz;eRtCR!7&t@2@2Hux~4s=3yeO zc8b{ZCo6(}fX79C&xfbVh)|@EsZnxM7)3q-H>MtNYeuWUkDJOg)v_bNtWdjDAhkRk zCrTI5qRMAKLx1a~bg&O0-6P&pc3mTgVh>ipcu{Mp?xLGvnK(yBC)TmI2c(hp&CT|c z8@tA1bz2`RLXeyXI})RcK73A}Ru7iGF`z1nOW?Os_^Y*|wP7;KTB}fKfthy(Cn6;!wWE8k{8=I%G9v)7J?EHW%A1BEGeXozbMf z0m0P$?(4G;+oWhdfA&6u0Ukg-9Ov}`RgMhBrTk=;C1aSI-x03zJ-`5uC9NJLXDNdf zKUcq=0QG?YLh=T9;LwPJelAqc?p+^Upf>3L4u>dmNv0{Ga3!%2$YhYjvt+o0{ZtU< zjRW<+d?C%Tz668YGanCNt1bCRAu76r12PPfzo5VRxH4WM7n(%SFHBZj^ZO22yJN*8 z0eDF%`{YFc7epxN2B+3J>+;ovfskDlo5WVPjl;OSUFq+JNPFVYs6J3bm3l9N<|{fa zkB{80KCfBl%7{OQK4dSnLCa9LU>sCLtTB92Bf(!rTN)evoTa$mw)$B-+1}ecbA_NS z96DR)w%+h2JMJiUOq946^$=}Z__>chC$8{C*(WIQ8!Q8j@g6@Shw9*OX4uy`PHvL< z20H*9h|yi(cSZ7sdua^Yf_v2VP4tw8c1a6|GZXc@8gMj-}-w1*?_tCHtF#oh%3h`3ND0-gDR zvIALiDC5$F?I+lcd2{vN<3naw1~>;8Y#A`ki<9k|0)jWheRkRoL2T8LHRci zeW8!=^0Km#o`pr`;#8*;;=X~*wG`17I5qp!Z`v9ro1mr z*6{hBo}T@RyZYMyYu&R(Shqk;d*(<4g^@alZ;c&qsz&@7pJ8Hs@Tx1|C`m`gwPaK$zPpPv;xb74q_iz3xa1LNm{9f5w; z2rY|=sK>$F9UFWcXyeSJM3Hef+fmGj7gYk)#T<*ai z3k%YnVm;9MWslPK=?xyN{Nlu6^oWGCC4yU+Dmz`BM5U9#IW0h1x8NGf!&5JAWvVnf z^nK2V#Zeww_ozQdtlk`-7FOu`Dy{NiS%H~p)yYq7QjrAVjv?moPiR;+EsTpU^_mf2--wiOgc0)NN=B8_RJR?P{RyM|0fqj=b~-#sWN3}7XiB{*8)xs0 zprzfSrN054M^>$DQSa3*_dK3HP_WGPsVG~q*T&yk+eK_IML>2rOWC*B=gHp=xHw?1< zrp?*l(_39qV>0h7YBI@dPUa9^31z7=IiJ@O(qqBW_vV9|L!)DRZiXIfM2zC}vpm-> za34q47@NC%v?4JJe-@h*M{*DrHXfat;3c7wgi&wb_wV9PyeJTTAMp9`_bAnHCKq+e zhq~&^ zHls{1P)C1gCNJ#O?fn(2%hHz{fBsIi8^2V{>dJ71IL@PjWRbxTyzQk1%Aj2p(LJF9 zAS`$$94&KemndobUGvoVHBrfUt@Y#TdiICGDUiy%H$NDbLn%;5o#WQ)a5BB`X_XwnE*`0uGD~=LbZ7PO&RDRkNwG zo{>My3;S(h{`uXPm&XBrj>!(+Td(fz~3)udJF6buR2x@hM zR2~H7fpy@|JC*vR8FvGA<>v< z+|9i(p>thYjcuJ!@ko~YQd+iQ49rKG zg;%E1vK}Rc*Gd3uYi#KEuOb`wsb7!LXmbNLW|9GQF^{vq;O|LN9|06I{*>!~t#a$( z2KgB|7%q^;ZJT})527F5k<_YKar9QF_!U?WM&ahLcC5vKiE{z{yqnGvbp4RpcUvUN zzSlT@o!&&>_{kwRGy=AYj&b*=h)yC4B%JLOHx_4eY6g?enhBM6<4s&(c?PH(V8 zZw9@6Lh<$tHR4=QP=r5}2zID#s1nL)_UUFeBW*gxGhrcAKuf57+pzC z$cMv|e2RS1$1xQ?5WPCAZd^Hg9)H(rP60>1;5&tM_UFzV$L@=5PlT=_C*c_60*68u zfZJrfQye(vFCaRRYijAKq{91gn6BXs{6elXz92?2 zmaB*Gp0Bc5WpjLjJr(PXncIc$^8*+VZu#4Yn|vd~I#ythjzu8!6F^%TxlgqHX#$0= z;#47{AmG~(_rjw1^MNgVJLC==-@><}OkzN@ONMRX+YuukguTTRU0fy|L7?&J@06nS zhWas5HUR-z{q14y=3 zVr?*CPk!`<k*B{vU4dc+;@8=rNz`hMfX$HeFs9W? zy*znfbj>f9HATq?nR^j(b4xl0hn@r_B$Pn$o2wuiPXssv92rf?*!3SfEx$aY>0Euk zJb8?mE@#8R)Gvg^`Ltz~M2PEphJ>d5eBq>eeuW#4%$Fx#ufNQ08WW_UW};t~25+$i zH7Cn$&XAsQ6icJJCfx|@do2{0P|9=+@ENABfmU-hed5{o$}=*$=8#5d!&)X?ZgAjf zILU?!mriSP4a%pn-nEf%e$pLOx<=bm-jg^-_PctCrBrQdu%fzA>BHu_*EHSaDtY;v z{hTaA_0FfJ9U_xlnH;2gdY5t3ohOYobXz;`GcP}eEk>&#b?|Jl3My^{Xfg^_?wZ(N z{uuczjPv56YCz_LrP_b&^;;{#FVH0r=M1g0cl7v8MWG{v-y2sr9n~AW5YF0rE9fT( zbva23zmMQ|i@%K1*(`@|D0?RDEnPXu7r{Kz@cJ1W^OPh|r@lmWRBhb<{JnF?3%d$U zDRBg>9L==YJosttV-Y!TwYo)jzH6(x6nJ4K)NKm}Fvn8E)KBE4 z`?;x`_W5bvJPg0@k<3fPZfRZ}zHi$syNnVG<}l;_56hA?Oe0$Z5sH@)^MF?Ev`yK0 z!QtEGVT{WB#hDnzOv716zkeFF-D5t#H<@OZoMxs>qa+ulZYG>&miDV0bCHQOE~sH1 zI#I9TnPi?NFY@~UVdF|F=8jVxZkkmkX6gM`{lWqE!O(l1!X{3#swK1i5kuZdZVGW% zzY%1(5tonVw{oVBH(8FrIzm6>wxa2dBeuQrL0L+F09IUGA!P#k&BhX_quDTMm`1MX43Kix zr;9#xFX+cAsBnX`+5#@{yKgwwH#GRJw(*HJU`yfmcxv*UebNQ6Bd}YKiUabUJ?jth z9r9x!7j!_Rv-!gZJACEed^i9R4?0EHhtPSCbt89nBkEOx49*&u zX?3UBB->#{AA2gr8VoHql>M$P?v0CGongt_BEAc3w=P1D;;0_vXE-I}lXX(?_iD=+ z6l8=zdMufo^o*f8TMA*ez>FI#{y_Ui{cfkBo)2fs{w!Zi0) zw=%E5!Vi&N)2T1HCtsgay4)fEQ!8Q9vbD%H%Q!NlFv{TL>?MD0DUa1s;6aFeuR2w! zta-aWKMCL?o%jX$1AJN;F(i#P4e@DShU3jY-y6qsUb%cBd356o$(Ik}Zv*`IejJu~ zY}HLvwfO!0nac-J_RAP>knlV00Rx)rzUe%E9mR3V8F3B5hRr$zt&!6WSqFhEVmA&g z>!T2gBpmI#RaXy&e}KAL~&Iy`qGk5Aw)QQuLAeryuUvmysHtV8m zj&`g0k^pwwNc`vmby4DwUNbKTPo{6D%F#E7j6}j@DPhn%Dbwio`5{U97|GZu6s7n~hOp1g!3*>C z_M~&Xuqg<0&q-mJgF(wH2_UR~`3Wx5E(kS3P~Swz3*`u3hG(C+RCEU`zQrZ814T{< z{H9yKw2&$ldeJo<1{BZX*)gY1bOoji>u;Z+3$;3P$ z_`O|G`~2X>FSqy#k^=6~Di{({wqdvXdqSl!Tx5PRK$aGx*Av{uQl=Z$2xg(b2kV>nK><*oKSMLb?1md3fM%aa&muDE@9DUfYK_l_%VXdD~ z6=-ly95_lY9?q(HBh;3p!RgLFPk$BtT2&w@C1J+8`b|{EA+! z|J87);Av0_1PJsbfrosxTW=C&B9Hy}@gQ{fsE7U%+DsCtX?8Xfa!+|1_DYy$15P3= zk;`^rv_99kowi7tkA2urI|X8n8*|>DWhhBiPPp%w*kBq@sq%U7jBJs zf{m!#U8!!xvO^mKCQGczyMyswBjXm-s-)P57h9`l$S^2pM!oZOeI*xqKlX zLz4BYElDZ~{OIv}wQV2Wh99gqs0(eBcn-qxmY8|Ni?FdOKzbw{kkRylO9^SEgIW{B z{0@Z6oknbaP%4RoWKR!cm*@wpcRX6J2ZM8ooho!|sM#pLU+t=_5Njgjc>7t0_u`Lt zrFU9a3}c_aQ$KC_bl*YgDG9~p+PgZuhzlPUDKfNv``tNAZh=mtcfspR_@2Ewokji; z4PpBfc@q^XoF^`P^t5X_z{n{qiMbfw+#>$g==>`iqu?Va&uuk1?mGweQ9!I1gZ$I~ zSj3P0S;Re>FInPerad1_+!vxUdsFJO%ai3Qk)ef(m*$1N))?WUz6*oQn`6^zN=M0I z1^@ovR&ZI5QOvJ0qk{*+qO>=zzB|PMHZ4Nw$00dD3z3WVH{tApz*-SZ*K4-@E9}gr zk{S%+d+0{u9uMx1AX?{%{_V+XPa{R2Q1d$1wbjHXT^mEZsqfLM{?C=PkdM&R;ZPcB zHqlUfZ+6)2A3ra45kc*!jW%jX*geF{Lp#WJ1IVYD}+B3Uv_Y``I_#O9M zBB!wHIvBI0tkU?lQ4I?7YM1Gtmo>NbSfD%(!1K$5&T}6e7StU4f^$8pAo8ij&!9;S z9U6+k@XU{CYViw;G=JCgk{(h(2eab4o+rd01Hs_~W9M|hSBe;;P< z8uwo!Gq$og9M~>y0w8||1=;Sk$6duZ^%_TAsJAvJJuuJmxM$kYgUMF3 zfx**6P=_5g&7P~~fd*oD=3ZskY>4cm3AY#FJ)TYQ_(b$<-k8Srh3GHi60nC@L1I?C z+_NIHdTrl}h!a12A5X(}40K-~w>a@T5;FM7KVc|dM?tU-6)FJCK0x3g!wrB z1T>WD0Rn4xoDh#O&inM!4})Lp0j+*WlkJ0-2PE!U!zNie(WY2_)C%M-%vBp>Y?_M` z_m@;=MP78#_lQ7;aZ2%m%cmUC{bs0`T!CuZyT4x};*zd#d0P$B_(%>qK6>Tr zGrdGWV4=$0uRBjhRLHPHshr| z_W;FrBq`c?Dg}Jg=^vepDNsiaF=2SY;$ncJgMf2#Ge-^OEj{ABy6Fe=;L?5|WBxL3 zT7eRS_n;n3#`Hl!ju;=|9+0*CG$e#(>K;_7Y5GP+@Bc-LrzH>*h>!ucf#hKD5%(fd_H@8Ft;-OszUSq^kydvEvCg3%7UUJJNEt>^~X?z%ysS14Kh zy*c!|gdd-+Y=&Lg3npzo%`-nl!mANkkbvo~Zwz{QagTgp6_|(rB!zWVBV1>ZZ3*;X{)|hZ*h+!XGMyQj9S8u& zmisInn#G?8WyITcZ<5&^dyin70vq41r6^@QFDrzJz>3y|)IuqqjrTXY>w0%|H$e>jIlIOJCdJfqkTT0?TfuX9F5K8syT06E zh(&gJ!l^tW(X+<1Ou~NVLW)aVsTUaPPG`5=S1!KylKT{od5xNbWjUP;^9aj{uuKi! zpPJ*7d|gmB=LRuadTrM zo-^Je`vaj;z_QCU2lQ8IFBrDq51O9hKG0ooB_Us7nxieRW}lykG9E;+#9MA!9QkK9 zJ01wTNVTmrL~!l3T47k(ilC9s_b&|3-o25*D_ktA34m$an$jt?>EjB_{rOG$?n)VU zMgF{A76Bw3-8Ku5Jw=*#eyaW^T_3E-RG=FPlzaHJY%!qd`Z}_I!9+$u)SvV$-(@)? zCXmglA<^GB1>_wSPq+u?T;JEF0cuiKzqzpqKh!C1op`lA4=Rb48{?Ea?P01*p8ALw zp0oZaZRd?7_k6%ZhNjz6fv5TP^*put?1tV#^phg;!xo+S!sh-RZ~h?vNP6)dmdC+R zOL00V(nyH;@}26-Hy(+lhGMC5UzBs@JvI`M3pXCW5`iqSf0#* z_f;zGX54)gn6c0%#rAL^v{o3YI$SNi9PDfhA>W9#YQvi3rvaRE;~RB2zkb+v>z%bSOIB1U{uk4BIzC^9%DIAv42 z;Mx)IW1Noue57y+Vc|2CZ|odqCod^Y6FqZvdrjJd8~=CTGf@uS;cV#=)GZkPp!APO zASD6jKOM-UG9vdUH;WwoiQFzEs=dlpD2(2SXNp}n@h7gewe>qcSsEv(b6k_^q6Qb1 zUh5w18sW8*Auw2V&Y3sLSkKg~VP*VG#xG{q>sv54=y>C~O07CI$KKzO$j*t$+@lO}$TvT4jsXbdFo0K6v#0&z7j$V7e%B3=f3;Az-iCr4D zgtKn*>u0l78u2y0-4?No%zO=}MjgdhcRs)5&i+3aAzA7lYZRArCxj!9->5C<)?z~F zUl>7_7}PuWgC{9l-VSlyskc`RXdL_r6LQ3AK=fy;^Zz~nds9%*XT|YQa`+H?S~I(P&6;0Sl9_=%!%%#hfx3H-oLgg7-Jt() ze%Qj8x#EhU#yOevs$^BB7c&biOGO<6@o9FlFDu7vUOWb3rVU1*sE}W?v9ROR0Ln#9 z(3Q}L7oaIJ$8#5K^d#^H#LRc8##W1^J%$|cPdUr` z?_)4@%H4j|Is_oykT13?S8tFtEGrnqj7#TETYL_%@IftFIaB1RTk%!u10F%2h!!#h z7-9X|$N>v$)aWP4>k2Lq&N ze?Jm744`Y}Q0NJmp~Ar@(pMlIF3ASV5Nq);4uU%_s7clDx(N)>nqWU)_E2*Bqj>5P zii;1EKoHjDxgePHHzEik!Z$WO=VDUu**ECy_Jnm4f*>BWNsb8ZGZI{%$=i}o5b#GZb1*N1|yPY|{bOc8SxD!=5+aqGJMT zYj2s(-X7(C&?dDe#7zL|&haEbj7!?4SU1l%{AcA1+Ukl;#0A{)$K41mK{?!!*+u&WzYENqD zo$HDb^YRp9_bnl{6XJE&lZX}hLD&B;PaSPg9mD0YVr8Za{`wyHeir9&M18W$84MFi zpLPmym%$E)TgcZ=_AzmpK`nEPN>+_K=rj?NX_O14?5QRlgT zMCs%ht{Ml0LC;tjEV3Ou^AuOg1YNA4mVIb}B5!W?f{aB&=h@Ax@2cVLIg2V8&5zGO zLnx-fZwo+m;!s$N8atJ>XDI@cCRJEeb_3P<{s7gztt2V2=Vcyt&rj{g0@LTMCr*3w zG2Q+7?Xw&g$}eh!BO@*1?8R%b`Md+!WNBXj1kYZw=qF{=MTIRG$sRRx!xV=v;0l5m zZme}z5xZ4QXO<;`^76K-ZIzg&bp?!Y5t|X2_7~tnYUbwUjiG}KX#XCL1=_(jK@v*K(Mw_g|?JBXRDI+i{)d!7hs0YVkS7<(E7$}3ccBUJ?0h%d`BnSFk_MtrKb4J31ne>VgR zN71hn`j@iIqNyHTr$gEe#Z-l0#A`vba6?X3SgSi6Qt}R+bcZ$sQv4)+)#-RSbvhVn zRIhA&du#pXFkP{DdyBw&dLNTqxWyUB+6xee!8%^B+`)t{kbc#(5-cGVf!V=?L21fo z$kwm|D!Gx!`l`!dU`nTB1raE118W>>Eb~|sM0ZtGzR*=7 zYd!#%c!%_kL^C{OTxu|yi$)S*q?<$GfbykF@CcRr;coT~{ZiVfYgPNqKbU)HfaDG= zs0)!p8=@f1aCn#8Bf_<~35g-$8{%CeV>CMl6@7NtAm3C3QmQ^X%&@&O3wSg zLxe^Hq*M&-c;nb_3g>HYq+&j{f2Pirm;$?tqFE(-h+H^@>4h+b5Iol6*DO8)v-$0a zE`b#Mt1z*goHbRK0MO1!2A;-jkKe!Mp7R}mLkSxgvu?zF1p}vVrK|B!Wd$eXP`R(Q z$M8jiV6F~?;bSoQvKKb|1weKQkb`c{?DW|nX=Q6Y@+T#T!uBbGUZSHv>|&?JVA_NM zx4#rLZsZS$K*f*D)A63_wr6T#4A$Ct%ZT03$jG~ME75kth3&o;=fQ-SlGs1ZKVb_) z=inGb;>Re(*tSU^HV1@f-+V^c-u>^&ers3!b!P{9VT;;-deFrmMaWWE%j#i;apBIp zLBnTcU~p15A@8@O)T4jMsMi<+P~)Vgo)|OE=CLaEiI;BIrTz}=H~&DA66{KQPbhV( zx-wcxxhA&~3`V)@3ghv|bO5dL>+JmZ<2-k%UiAAknzC^BtohJXM^oKgz=h3dm0dmu zciQGnFD>4Re)m9|Jyg~G`cvu+0So(*_tTWx_kM*0&5!lhIYnm7i= zXDj?mg5QwVRDpV_YQAzE1p_aDn?@2Uhd`BD$Y~HsP(>li=rBumLFXO_V6#s1+>Hg* zKT~QYc{xIGL7Sn0t6PI+>fHt;){KG#z{)_3adxmm9aYSbz~&M)@&mP*f(x~wWwM3c zg*7Ko;qxwwEvIK2Vs3Ld7q>cMTZWQ8$XUzLa)n0B>4Y0X?;znLh~(%NpdsVDx1_;G zi|Me_Xt3BQ!l%O^sR2#2^A3GTkbQ#+4!BMuQ;P^p%uDY9t{lH#@%&a_-etQ?q`wC+ zBr-UlsQCD?ClE`vu&~fCEp8{2sa3%zEtX*1f-;xEFG}GHIYY6@pVL={K8tj~@)AR$ z!733P_tvCg7smW3PSO^HU9{@fHr&49(qD}Ebe9aD?Zk0vTjb0`Au1G3aqnbNM1jvbdWu)oUk_BVwfip+pNMWI?kdUMoy9`UBwgMD)8_aiGSJXh9Ii8Ry6 zQm?oz2Pq;nrp7spnr#ihua$DMr}Tl^icv>|-xSZ{d5kP((T@!*?GZIe8>3&6w{k zTMtsuDet%2BkIfgql=2Gw8;e+Dkyepi#HFlA%y1@uY0}E=W&fmCfilU8jDhe&DNJ6P7m6yMih&j<7(4Y58hijKD=Dg_ZQT@dy>0i272RCXeEc0D|NsIgr z|5A5_+)b^fWxR?cYzUm{Cp(t1N=^~~Bvcjkl%l}ySbIGEX^&@Vl3M;0zwdITmhl>gSUZEQ9s%|#yC?MM#39{lsj|l6%Cy6{On@FO~wB{oS%>&w4|90k>#u3Z!h=CAn|&;Pt)5S z?(03FspH?w`OWG&CJ9Zj1-g*)ar!*j+W(l#KXlfE3A1&dc~Qv^*alcbGY^8Y!eJbd z>QE+=tvQH8*bSRV$AeG`F}>zW*LOP*9t)O4ME|M*@n-|9AZL(iEy`uC#GQS&)QSN? z%^zO&PTekE+Hj>^*fI8vh$GV)-T)T8{&lE+l3=turxNW8qVMd>nP8$q(UkFzbpNN1 zsy_(=bKA$3egmcoL{{~WPtYX7;E}cEv}N85I7I8Aelz$ZGo3YZn1Aj8PCx`TU+rme zr$sak*VVCbFfr(-pbQ{rFvw?WL@bRx${F5yBND0in$;=HoSfS34XDxtjhV;0UC<;? zk~;+Z!5nqkdcv7aaoeuj!V%DIooKplCpyJk3xYSvD)-G%FieErIgzdCuuEVX<&|wW z-34Z{@iKZQGg2)D#s4S}}H!gX5_2ZU$a#6<> zF%O)>)b(~%L_!KLPk4oDt}&Awb>sU^%;x$?!;6ni^~~z0zHDANp#mgLuvulSVcIeb zTv{|jREi^yEQg{~#ue>?0VTIa4mg><9<|}KJ>Ig1o}Jus_*BBPHeB}BM;-W9_K8L< zqfUa4;e=9&mp)}~e;Q6G{C)7V$3sPwnatfP)fY2{(lFzB1+$PR-l*h;cNxqP);B7kEj@p$n0;eZ#37j`N>eVH1bLWUd7p_z(HQfX`unnnWy0L!@qX7C6`UL*D|DRe&bO*wP!A1O|+L+8YN}IzOg=C9rtg^K_iy6QF zQ3YLT(&w17V265yr#30Ehc2NyAYWi4-98ztRM8er0@GqO@*plNFzYPJ&I=%8a`2-x z{;8#C#zRK=@aJ!p_bwafADODWH~Ts8<#JwtL!(h=tY66Yx&eRqkN?x>*TNmA|7jxr z@O*!k_gMJZ>FdhVUhQh`cTK44nfv7D^{3U0ZX2>YG=^t{m4SP04h-*`Ns3M<8td*8cEB9AP-9r;?rgRvFiHo*SJxy5*0mrOufRCD^WVU|e`Wya;MUE`9TJ^6>yezV(J*79x4NS`H5?hy>HmVq$ zZbUSlwy7U|kz!!9z`-i#)+}9nJ$O_AD8{t}q=N3{czi&z>&p<=)cOcAVw&0Zza(s9|V!z zA)H!DEke2db2|}Y*p>OOOf``MzU^ecOUKUMj{Q!5nqdq=fJr`Li=zntpe^Jct0gE_ zt)-F*_CpGa-T3pDS!R~2oY=(}xxPHbLCEn3Kmg`xfU7;!N(BG0ko+NO=sQ}JPSW-i zTP4g4lqo}H%m9b^!<|VmQqH~m`OQb(ZN(`15I5l)`GYq%mKh-KD#R{fmte(BlB5FJ z(+5K+jG|!5`bwek<%R8EevjuOFvTM(-NZhe-Q)QN_}-ek4z&ON#o!}RARP8yzjW+2 z|3ALqHXxw!IL_|hIYIor%6}V$5E&#}%cB$g-!Koy>80W7-KO2vB zt|vXnoV&f1|CJ+df5Afpk^XzQu{uGoJohnS5!~Nalfk>sHF!OD5nJZ4!s6P09tupsh9+E-RbKyWcI3#9clr9${jd-V$ z6B40?EFX7uKS~cO2WDZC(y=30b^%;Ral}cT?9}cc#yS32=3az<&gxxsAzz<39A(|W zu#n}kDFxh#a|I@yPhpR#N4-X15iOmttKW4QgG%6e=O~&t=EZHG-fPH?snJT3`7bw= zUIeJ^%p#wN_W$Qq{__vumx9}+I-}T&-6OYe!{0BEKL=QHqT9EKu=npjeiD3yD1U<- zWn;qs@&%H?@WjYJkdgb3e~5i8>HUzA^8fxFc*vwqToS=nAG@vv`s2S|mjCUC#1f$= z>i@h5!^ydyl5;y=&diu0$_1xMM=~>)U*X@^RqV$=umakJi6x>eh)%pU_iU44h3<3}Y2XaFqg zV=|VS8;)qX!0t9PTp*AbVK#?ZQ!jzZP@x5TpM}#Uy8G(9Z|~y*l5qUI#EiZ9(2GM0 z{rhl?g{kz#)Mj!jT)c`pb-hedU^@{=NN=ed?htvZTG-XMwyA6E!VVLw%O^^~W?>5( zPFE4|k_R$c$e{gDN_?tA4Mb@#Y#x>B!6*K4_s=ITCi5SBOE*8@;MF~{KJPVcF^E_@ z{B$W2%wu%z7dl1ND><18b~$Gk@%;@LD^iY zr>j)G(f;U#jht*?*O=T=nX)Lj&7{aKw4%T=V?UQ4W?oRmGLx#1(>ai8?j)*|MqIoY zCAcA+sFd+~Eh6^LGo0YVj(2ucusa@xts8px9IPpU*B?*+Z_g825rw3240`KP-_r?P zN~j?ij*7;UmXek%mL}6c6=P|+^25JKl4VW68@!DnkaLBy%@b3Lh%){<9V}g{<5%dS zF5JG4nBOSot)UJ^DBzl{dC;=6zwPnt0P$?Z^UbRn*mp);6a91)^;`1wxhcv-4o4!Y zg@cxx8z+@D&Ytnu$U$li1PiyU`uJ^_3Ice-5PK2ybx}6fkkd9;c*_Nm1i<>0mCVn( z!;9$(>n?M!T>abX2nHFeSm)Uti-{1G9X#tVErxD#U^gN85FyOCKHf@hk|F=TP7_cl zia@n{qQmWEfdDHD2FY3jIPE7JIwzXZ z0PwFjnCPonhiV`&VxWcbybi|nNzf6LHG(4yNYO?bv_bU)O0=>?$KMnDoEP?ZYJWLD zK2~IQAgxAq$?qA1dOAkM)=%;E33gS`oK2V%q_2tHQEO`K5@7ol<)bPqp^=X<@A;|P z7t)kPfK!EgO0G+$H$(X7Q;MJGq23|o&S!rl6p=#%SaP!_ptX^5cjkTqN5evpWV%-Xfk`UE{acY@l#G7U|le-s;r0?8|q+?$T}Hg9(?uB77ml z+?ClQI?kNt@8q}JoEol*+?8W9FsXdyjksE64!>JI<>(E328`b|d{LR3ui~5UqX+|k zNPTN*zPz)>3Xvmhua|5*M4tRvvlEFt*=P0rMM_~pH@6_S-<_*fJ6#d>lhHxE-8k|8 z=Iq_QL2f;0v7?75!X^+-9I=46aN@GIWPlU@fw?XKJ?Obh`zd`9f5sRFb!-X`_-Ic3 zfBF)LBZe^|VzkcVFiQ<1ri-B%JF)|$1(qSi&dj$w=n0ma>D`durvL20@uAVNXa5D! zee57FUG1*=HFiH~!dg$e0rOJ*!3N$?eo z4WM;j1B}zanBxzzGCE`Il@p@61y_!tZH-C(n2)m z5_&H%Ar-0)5aS#SD!>MqSdRj%lxTqH2g39~qfV2P;0KtUBRH=H@-II__If1naf1HR zNUHZ48ak%t4+efARmo4TZ2~+{F|=vtV<1lxag|MzeIm)v{X0|SBw?zLFZ{vD^vZ@- z0L2)f*a9H#7=W1oVxmQn?{O4hv~~d^PqNeVHf?XbL9#v8|=OMaEWuZX+8Zxbw-p;FFzJyExM4 zU0FfOdv3#4P3jg0FLAeio>QU6H7Q7@$VVq2Vw{Zr{{P0(+4`xwSmF`8SmKe0B_6uN5{Jnw3&()i%ZWhnx0n|{x<?0A7o-<3h&ip5YZl1d`_Z5wDn z$?Wl>dw2jd$enkK#yD*bjKKPBi4qop4bcAoH(xLj=}s0WnM}6b=O)L+i$lJHJ(~Cl zX)#DIlUdAQj@_~NRJ?zWK3m1M;AWcpu&;}6m%xg9vMG>8oj%pEWcB<=&CLnA2i(6k>=6F{49H7k+pkcY>vs6~oarZ`#ZV2g% z`OIES?$QBbH^0oZ^6!DD$DkDXM&2Rqs5;f))!5V7*?lD(H-pcfPun!Vt(+v8OfJEj zeTAlCPo?nr@nasxQxx$ovmbcw+dj)n|m4Xc&}!3Sc|mEoiv?#K&rYhJSVhkbw^?L)ZB_@jWo9HCycK7Rc+Q3ZwNa$!gv<3cC=iF#TP z79c2X!Gy7#(aLUZV`BgFERkq?`uYf!}slg274c4%5N`U)AgNj_UZ4>+lrKOrs53{?i z*|uJpQembS1VIV%Y3vW_|8D#$T8UnU@W8RtGdO*m2+7?sIAZ>q4{D6o$GzmZVl)=J zpM11QgLvxn$?YBoOkZ=$Ct40>oE(C~BeP~{$9{(tC!Q}OpzC4;iCtOb=V!ok-FFZ_ z6kxPnY6==_XddNqO$+i)v|@8MBqI6pO&=IJ2|Ww-89iwZR#f8 z4vP?=YnZZ#ERDTJ$Hr9E@44THSE(?Bpn#?#?6T8~3UZao-vZgGj_ZYy?{p8ER1;tB zFxjWRM5ClpQv z1>RE4Uk}H>DpMX|RlRtHWu}fMp>d_0lGierw%j!Wj`I!ZK-1?fM;Vkf ztJA>;so1R_XqhK;Iu;3Kd4$(i=CiyPnlH~6uf%%P0z^jvAUcc?)$`Xi(lg6j3gZNl z#pQ|Z7J?@`h&%u$0Bw7=OKa%!U9Px%c zk&|QLevsoT`oi(cVX{kR&s?<29j$_8+IpTcmUhs|BVf+cjVr<3crT{?YwobweWc!b zztf?rCHiiD9meXHN{x^GspP3R4&x4n$VF>Y^4T8pBL!lPr6246Bke4pvRt>dt%w*{ zNC}Fdl#)tFD~fbTOQX`=U7|~+L>fUtx>53@v5=A!koph_DFG#=|M{TA+9$>t;~(RU zvCm%H1^Ds4&ok#e?|I#nT0p2lg3v`t+gQxa2^cZk%=_xM88wrGr{?HT{0r-e3jYeY zM>QeyQXdDQY+y8{;m?oanJY>4Q%dCMKrT4BFdpB4PsobQ?Z^;~aDh%g z-yTqi>}1CD4un}djI|20l#Is9YQlJId?B#S4PN)|AuXVT*f%O_Dre@f00Vk)Z7%IRr(Y zaHLD%{mS8N`*sbF^LZV3D|P!uUenwKpo|X7$^1qb5iX6q_elq!3~>Df?ivBFC>+!0 zEZKN<#Z=B<9E$JBc!MW;yzj(y=fF+*m!6gtNq9Z&-?>cnh-EuIbDC`L#iaUXjZm7o}qafnp;9+-wph{4N2S4 zci;A1!#0sR?P3oFcthPGyqUg3UF{h=HR8LQ-7dz7S_+pk%dr-}v(i1#XQvk^pllnI z41kzFFg_wk9t?y!tx=D)NTcL09}@D*)fKD9eTW&OWZ--ryy#Q5^I%Cq=>Sm+OrM3K zKK^~a*@D1Xpu^kJDgYwXnG0ya%jVK}?x>^iVw9p0tWZ5S2G|sK$&ZZvUF*Cq*z;H# zniMYVJGA{(AsZCA-I7Y5)NQ>p0XJ}QrfXBp9HC^_g--l#0}4^cI>Dd}zR1nR5GUZi z9!?6^eFv-qMj(Pe8G_@foLGlButR{NfUA06k!;w z{R=*p-y>)v#19;oiYZX5#!KGcBxzqU4S-6wSg(8kY_&~hquj5O3l{6Ej&1K|&9UU} zXPY80M7?Mux1ZAepw{XX7E!Aws`H60DCL52f4e@Z%L|)y(+Z}0d-lmv`+SU*5hFz4 z8Mr(;e8rUD+L0N|jyE^NFH$cA0#l1aU1a z4!IsDnuG^zmMpA@!m0oMP{WIG?K@?nUp**aiMX%7C&-^A9N^FG^i&I}I#$sAMi~eY zIGcM!?$=aI4}&=pI};sq(5+ll7KVM}2eL5Zs%Cc})NoP|oyR0ouogSxlvCKV9MxUG z$X$Hz<99D*kj*(5z8QH65bRHoc=DCat`V&1!wd=j(%9h3k!V`nT5WSP9UUMuU>~ak zwFbMJ6V5IA4Q;l}2(OdDE)l|%i@;*Vt*c9ki&$kDV7@XV%QARfMf7(cUK~*0tWH(? ze=^wdK8;Dk+4XVG+x<4>9;7R78V8c69MD{j0Bz&k)P}dVCWn zX@}{2?cEGvgv5*v8Sc%_8E4tanCBi$umxW}#i!4PH>fO?1zrnZCy!j(2lp47`#cnr z-Z(I5mpk9lw0Kwtt%GVoy7b0q0_=_@%bKLYv{vY=G1dMQ{V)0C++)4=i~_!&t4fAC z)vp={)rCIHP>*Hr0GyD)r}&Hyn`9xktLhc%51xeyMTU)JU~aSFcuJkOyp$33gN1aJ zOna3{YlOE<_b+m`P z+yHse^UU^%^B&K1C*Z$=m`Z)rpZB&IfQGQlv7GfP1q-($N2@Dy?nQF=8UK&-nIs{q zyX<5T#dHqb;;CqX3GM_U&Nsez?n~g^lT6VI*)<&tAc5!gbNq|c?T9qUfkBj*C9R=B zKpAM;JJ;j*&m@-{zb0~f7INP+liK264hS+xl2sMq4btB#egl4?wcKS`_zOjNJdn%u z;4qs3_qDf(z5fD(*#KhZqiFV42q#Y)nB-Be?&0f4O9yD7_=;)gi2~lgl5A0#Qn^3| zhjqS4O2>tE;o2RPcS1Tm<$qD$j0}8ZyH3Wlo5zsY)x%r=EMz~Yv}&&X(A$USXGacDPKO?&X@TVy zdERxVWo5pgH>_2zI5&0kx;-KPMREyt)0YV@4DuG>nFDX}fsx9_qFAB321RK9ClWq3 zp?#G@tE(d{k-V@Lp>r9Y1h7%mBZD3Iy<}TxX~*hVngnH^+)adBVMS;)=nyn`b^euy zT{x&R42WU7Wpjl92iq~{XJEB)#_M)<8cPaK1oSM5xY_2iKGbti$&N!7~ z%6BmG)9~^2Bvr2*pdZY33>+T8kUvp&^Sx_w*>dI(%}C9%aCD;?<|M|o1uP6QU+=ZS zOY{+5;O77jaQ*Z$0*z>CO+*KDk8o>b*;9d;p+C$1CGh^`D+)|A+%FA0$MV) z@x>Q=W@6)PZ3nBpkR(+?6~M7na46)jm`4N*x0>}7hJx>oiq`uks9nM?km)v_ar)^B zD*si%fNXPq88)Jm7mL89qEQG7F}+PweRkeOKD^8Tj*Q7rumTO=oNNgoHXMg`wRK*w zY5@x|mua}S@hH-VU`afc9mh`OE4**9S5Da;%J-SVjMhT}75465l#PppQ|MGv>SWs* zy_*PuBPoUgZ357>#k|`svSlrOF-DqB>+qgu9=p6Zbfa z+?##Z0~NA|R~H$7+;y;Dc*&JuY}M_hSj7iWhBDP*!t8lcAT(~fp zTia(~uvNV9M=<{Ov~BcrM%s8{4|2_?4fFWEZu_i>o))>CFDS$xg# z3xx7w^~98vQir2j2}U4@PYEo>^4NT@X^iGeg7*@T09DON?{t$PW~OdQ@5YTCP>&}O z?FszTm*C3`RQusG!ii*;#_wymVyk>2=Mo=bOt`z6|gcxcrc%Sa6hRA6_Ovob|St7 zOGlAwU`(YxsajwhS~lOIJ_EORJ!*eZzzY!2M43+4OGO8!&ZEqAa#);$hAiK(KOFYV ze8&~xPD7(lAtg9lv4HKr%*ZGSZw)OnFCaT1RQoq)3(AXw8%-5S0NN1%4|fS@yFUVv zy$^<^y2#!a7#OJj&o13Q7pjHh4g&T=cXzEytm3u;Nb<=II?*nY=&?nZ`zV2~rb(YB z=?{@mEYOP(Mw6xHZNQq;Xp~LKNSt!Q_1lQY0hVsT7}{z%Wz8Q0s48S}m^q|EA$1c{SxCWBch=Jo7H*;YD>uks2K z-wR)mPaSe`sl<~Ru0KFx_hA33nBtEY%IS(f;+GC9rv|VxrXQJj{!1ToU^ASO0Vp?0 z>A2>Fgqg|>bQXbE__f)~DbJR(8rL=+TQ4`3qv#Ss-GKi;LHX3}t+Cgx{dRoiNab%2 z(6*_A>O*ib*}0^ZyUZR7GHlHJ%G}aMoR&H4(se$gM5V_(?W-C_>8iXo5mysbhQog< zL&thj%rlAjs4D+CI-Q81&{AQC_XfTSa*E81i0SXd)ZCiXe@stJ0&= zqqZCj4?PVuKRU8rjAoS0Q-0%jU|pf-COax4kzr*RiaAwCvX8!x5`ThJgbjnO>zmOt zaw~*Ow8I6Dx=IQDfFE#i-Y%a>dwvoOg%=xLxfBo>{VG#c{p)T|h^4RQDGO;omJb<@ z;}}Q^KW;$QP~0wV7+PI>O4_0-I}=^6r%K%~=s2*+@SHRPA_f2N{)fkW@kG#==*3&` zf={%i4eDPByR6T#lb0(SoxmX`2z8#MBr+3e(=F)W>$Z<$Q*ndN9Q0 z%&Sint{yfGY`>zu(s14)Vm_lp-Qb1Fas&Jd)MKm|EFzwsY8q}#F9p_Sp1_02lL_R( z5@_3u9+ z#*3rZSNl9L@SE_dTfV_(qLmkjMW`!FYnfbSb;@O3w zbl}10>gm;jLBg8!+o^4-T9C>cZbdpdgShRbQY-ospJASAA-(r->gVY&b@N&Wz@UFK z0yJPNcnw+2MQF#m@=|O#DuC;P7L$%4j0lOT90f=5iSoy5^#NVID^8;0YeY2!@oJ6f*T&19iTAuz>7uFw5 zz`FrrE?q7ldPZbTu_XnJ6+OOEUx3>d0keWKxLT3_qS?&0LkCKMM3``bEAR%;e3DC( z=k@gU>!7sN1?bSg1i?wB@1A`^3JdYVRUuxoKYUrw>yoh7lM@GwwpQU%D1Rc)sXKX5 zN-5_iaemF?y9aHJpSF^uw#Ginv{R4&tg5TzW>Tp-#7xkq=#WtcPg`z!DQE84mZZ>o ziWg_2Y%~<#{P^)Bb?8D=B(CcUbJC<(lmF4EOnVxjp2uO_xl}7YE9v-32kHBbn-#=C zd{v7*b8m=C6^%cO*|%)7X9l@*FTj?3x^w25&iwe5@n5E;o~NO}$7+2Kmn41KZ|Mz* z_rlP7{RvTjHJ5{y#AK|^qUrxpfep=8~92ub8_$`V`_tC?qM+?!SE-wzO+uKX}w6{1wXhaAvar;<$Ms}To z1x(1WN8SUS`JWkRqrdsT;4~qrn?3ITGwf*UryyEMU2vPVLT)sM$QcANmQG2SWUR)= zEyAT)G6;K*{2{lSJu8`bSC`RVvgWs?sBOohC-^EMMLRdsPp|v$11qM7gK}(ZdRTwv z*0b24ea+vcU)NZUe`aswe$O+i1r9_0r6j9O&?$7t6H&1x*9Jf{WblBU!kdd7E!tk! z038)T4kijZY(F0gW?wC56%-rkX(*f#v@^0!m%>mEnf2ikW4I}Mho7t(r zP1K%1M5B6-jg_yiOSI8e0R0TbG62Vhm)TSQ$_@EW4N>=E^8Aq+tMqq>GOFIu_ zw#s6SWJ0Q%4y>*H4vbJpY6ue|RdT#Js>( zynQZ$wx#wQI5`+EyMs z6GZO%c$eQaAofo-+Pc;CV_TO{ttb_pr?z?Jn!Q|gFREUVX~nlU^8#d{d3ZPOQ8ZA8 zmjW5sk|Yda;J4}CKr&;zIwfq(%?yWY^Om<+bwVSH7Er}fppU4B`2h2{>8HB7Ac(^f zhlq_$7Yf8whJQu8bEh5NIgDX7BmT`hhe}oSG2zaK=!r_N9O@;et&dgq%E3gq&5nPJ zzevtg0`Qnzpno<`a*)ULRlvbM10>5@6gSXZ(8~Q>z@^ZaX038#?N-MtZPt>F0Xb0$ z+LRdQO&N8H_~N}-bHVJ&3ll-cb|O~*0F1P3*yjnKo161WzXjqp(om6m3h)HtFr3i^ zk?4EZ3wPTwQ&ww-{<$hrPhn_P`*Rt_m@jkDTCje4Q#)}$A-9h7lCn=yu&^>?s$!Z- zJCRCFj|+#}>}{zn7Jy}L4PG`!_{)N?)>wvdArmYy!%a;+u{Y;4xN;U=2=}3A^vszs z_q-b81m-60&`sbq1V--<_6-eLTqEbE;69t!N987yBms`dm*e0wdQEarSWTzJ%%nFOGwr<1}P{xb4&j3c>R*e!GUISN1>!nr+(%UmgVdi{`{ojT?6W;0;{WwaxSLGV7x z>TAm8%1-sq3&_cSwSfyFr?Dn4A6thV*UdlmB+P!Ix^2c)*-5g8xp&@1KYA>FIrtt5 zSkk|MU)Hj|?=3LkEGAyQsWm24(KpFx?HEz)e_XBFT!&p=yAycp`$(RW13njpHTtF< zW)wE?pUjCR>XWj*YvzGU*sL}EK+JlqMb188Q`(1>lzZEC26pe=Pe16~D#YauBHR#y zbQFumpoBl;O9i{Gev%2K}&^yH0D7*N*Ne*v5@55^F`_PZ&~t=ZY@`VMjIh{d5My z5={gC9l1`)`6IWGnhffJB(p`r1Vi6rlv@)5hPIN?u5Ws4vs~*`{c-)ykGrUjNSZcP!JgTLGA81PEKyaBncy^X*o(LN|J5uk_ z%4~~t!$X?#C&|LTxevYgBJwG@aKl2Q`5xA%CvxVOS3_v|KLuYT>dEMr)rOdJbD9O- z9r=21QK-tNiY3x=|6l_a=RjVFBsYCLB$@m2 zX8b~-mBZjK@P_N?6rX)2A0?ZvSk2p1vT;QOXrP6-%#Y6~r|W+1?qQ3;l0xUO5ht*9 zebS0j(MJ2B85{L7WNfrdo*o-)AO0MP+`tBp`k-w2Q|rPGw8CAsb$QcDKUTsqO#>zo zXLxp*_2`}$ptI_7T{kgK$@pRbb%LGw=8$i#23y;}60>R;9MJ=$ta={``{6kh5OKuA9b3f=kfj7ILGkC8i@ zFGUNyi_~F^4WHnuu2Wh}xwdw3+<0w^R$k}|v~r&95+UFR97_1p*WnBQ+4Dg@ZO;$q z;0x=3L*=k(3-E9lpvnkArV8&mjhBLlv{V4x>5GSjcXbS2;fl!B2%q&MRX=h5Riolp zQtt3YsM|)Bt$eF6^#NKrOsb$xZ$z|m5WA+hp*oG*C4vI2929T14XtncchkzfsQ=Q+fy%iA zHrs)YHZ-0C%1r+ztz0Y(dOQ{vC;*pO5OTpDwr0^31i1pTN$J?ZnsvL8mGI=ic9jaN zJ)Q6g`abn}uzfgDnd3?{)cyu4lsNosgB{Jl=c-4RZnm+X_gO3C;aj4DT!aqVcYqPi za}!bg%l$1sezqtb*$DTyZB?%_g(vdR1)x1$l)Y# zZ@o_^y!LrM|55uJ9Ll1Kn&GEB5Ls#N^R0S>(gZ4(mr6DV+uD6rG?4RbR{}+JEWPn# zW+oHniSDjCv&9vw&F0V=@f2yXC|759or!X7Q0Q81O-&8+zi8!6@F_s0d1>oD-KsRz zUclBf&qb!0Xyq2jlZfE4N$z>ytfrkm179~7VkN!sTwBNWmRfwAs2CjiYh18@>2e0;h_E^2D z9ZSsj-uQspr0j%EdgoE;*ALJc?KAAs!mxQpA63Q<)H^Qn5ORFG!Pkeue|i#BQ(i0RGJ zp^KRzQ5UQBYd@RIU5XT%w`b+8J{w@uOWW}Z+bmK~(Qwgz3A0b?3c1x}!Uzxm<)bB!}e)d(m^CRQod1gyCE6gYJ+;zfJMO6n$ zAapR7iY=%13#TMW@CEFQUm@6#?HI4tKWxO2p15-L>pY2-E z6R#Gtok@y%3~|TEKS~budGBWYOmMRXW`(+~YxA)xX?sjhpfpOgM+ zqLqMQYd)dy_K@m7w-HaCKM@?^w^FWM3hGdz#B+L$#XC_@0M&oF&j^LNvkKvKiSET%T$1k8Qk72Y2f+p0!G`R6aa<*T4V>ujwR9Q-F z;mktJoEe~2bqGIe%z7e%IOY2hr#!UOF-GU<;ZC4vAx`pOKxil}oCx&_|7tqRa8Gst zWBOCiZch2TgCBH@^BZMNAs9O~23}!gwHnGUoeBb%Z0m5wUY@_wn9#7qU_Tfq>|`1f zo?TUA0U$LRJq-bg+p@k0(Q8I<=k4a^c>tcHp}NuQaSc{W*530^FXOL598NUhY{^9S z?85UtYTn>Hpk$cS86a&mWMoT4fI`3jgy+oyMC6X!`OtZBF8|4fgC)1K=}w@Pm|wXA zljj>wh$5s1^+DWhE4K=WXJ~A@$xa7M$9dSXIi!?P0@dW@XVh(#jgy+`?#qR~7zRsnqi#hzXjb_^!Se_vo5p?8h<)o= zSH`)l%A)y7O91V4iiLw^MATL_m$z#*AB#`u{tqXAnMdBe3Je_xhNvb0AK>1Rp$1o7!cx4@5vD;Cr#Vul}xiH?*Tw5zz;D+0$In^*JsP zx$W7eLLW?w)kiXv)`>X~$9qZ}a9@Bbhb@hBNw$aVKqnF!NwR8&<2?;{yvS|i4q=h} zN`NYlbzc?-*J~O7lJg7dQarjtNo*k$;~PyHt(uW=g}I)2;T;I!ER}@|W(lU}8Zr8#Y^T;j)%K;pDjLfx z%R{kG&+y1Yj=LutthHQ+tNumn7JBLGX?e995vi|4=td51T? zIQ;1RhEs0%n$%9^F}YUIyg78x-z}a=>MG7l;FbC6&OPp+ z_&}|*N4qs-;9{jkthim?YcAg_M!}0@{^~vk*p9A_zGvjgx<;jTWvb&HBDd3t%Em}k zUv)erX25&Z6D&pe{>#gGGJ_wOhoKAy%uAK*CM>zEz814vma;eTN`5#!-PHxjM*D92 z-?@q7Z7fRXdo)cJKactrZ(1?wk4rF8E8_FFADY4vR7>TsGau*8pfGZ+AT51LD`&Z} z9#mv9*2^wXZaaR_Tg`7F0AqToNV(&eG3Q|>d&6vK2N(egRIW`e_;E+ z-fN(()j;j(qZdkj>UyU+^qluTn$LSop?c~1`{2=^(jM~}+O;ri9g4Dy>;WL@5=dSY z3@_RDa+b`(b|*m{W7b2o-o~ss1rthscq-$JUqF5;++N3cPR-vSo))qP`MKmJE4imy z@c7lXo>;j>OMdVhY|{T>itYf;Bq5R8Q>ook@0vVTU7u!3@ATy-YHm@?{gFNf3!JX@37t`HWdzl{#`_nSG4Cf-t2bVW2s(S0=dY*OpUOHAJy%(OI(KFe{L8{2uyq*yhBLO zPfc`TG;Vb}wU$_)Y)6tH(x` zfs@c${Q}LwPBH)hkHqPVu*H`LTMDFrNQYMwhG`0DC=)l4WYvElz6%U?AI9PN5!6&% z1~afGfyXP6N>^y;vERU3yqej(0{SyHN1M*pbJzFYD@sKG511Gp{nkNz!Yc;J>2Pp2 z5RM8bh{If9!#?lySEmsQR<;=&p0UJW32dz+!up@TB`_b730Cdn6n1++boqksr6%58 zXf}~O7US!3{)g`+kVeoPNHYR%Kei^f<%P+KMlj&Gcfm+yG*Tq}EM#tuy8A)kvGj%5 zT|d_U_>qX)8buSR08dvodUmAAK{5fjp~EwE;jFA9v45s`acpo#BO}Lj+3Zb|QU=K&+(<1-8UcGFq#Ik4hsq`*P!Su`hXvn5MbbUgE1@RDTH0}jv*e9wjZ<-*99jaUWd&_$s zznyuR*B$tAq%`ZB$k^J-Q!cLrM`!0Vj9kiQ`AtC_(soc}^F!RoBMbm5~|MZEG#;`dTgP_bHdQ{lu zSB8L2J9b8uE7RO^%`lB`Ku|1RN~p(X0T(kEB3o=-s9G(mc&H+%HnDivChVrr`*>ri zFiFy+`*=6H3QWQkv((yOsFD*_>dPsGMNE`@5{_M({k~uOyZn+@`CBKIO3O^>*0}TC zyHtD3Lyz~FTzQq>ak}YA3jE#59+l27fw1ilX6U+4xl(k{G%c&^W%WKC2mf6UTP5H=r^wb+w)64Mjb%VGrSxOw0G&OWK=ltWO@AmWwM`MzDTJ5I} zYVh<-5k1XT#A+1LWK*a7%>31X;+Re?|;k*3ypMMk%!bsGMS^ zs6EYT2CdFav(RyIxD19tJ9A)Xg}J**@m#CqY(6cf52C-L3^^trXn`u!pc$pJ({Xto z1i53jC%7|^0PpZ0z6uNwXgD^fy3nlk0;d}VKkfvSD9YG4g#c%}6o~!Wy&Xv5(+sE=lU%%*rr^=_!JAI6;e{2RckE ztOSLBO;eyR20tEFs=AlA0g{oKXV3CGzFq~IZzwd1!xcaLAxbE>*==j%AHZn3E z>%<>@)K{T%#)mNB>&jPu_wQR8X7x}t)>!RA35AWfN8qNp;e@~yo_BAM-8}$( zZz!TQu^380qUZbUv%KwVh?fUD{DujpnI+oM0p$HOlzIz`Ic0ogz>H`~2 zwa~dy@P0+53HK@$sQd0v`}kHI2k$!!Ng`6brWR8rv(dH%igyGzxXe0`)+Izl9;2if z*x!%kAI9f{gu^f~R&0IH&RG4wp}4ohfcg)`-Q)`kqPV+4E6L(xtELu-hBA0ZPEjGF zGLu?0E(QZy)cPTp9Mjo*YWB6mj@r}2+nFb-IAW%VuN4U*1O)_ns+ptxyO-rFH zBEIXo#ktUDou^egl+J3Xl{GLiw+>`oKh84E6Cy4CRy|0mPaYHMLeN90qR1%U0hYxI zvHkcLW#5Dyek3+juKVZ2?)EuMJX1&;ndjI+@iZo|PObQKwTZu?U!qdZOrZsK@REI>e4^457$eRI2RY?5C&TcN(xpgt zNMbT>Qu!ffg6$}*{_LMWezxJ%MM=H+_xj?*4?R*!jhPFdLq`IP_p&7H;YN8(QYGQ_ zLg)E6YC}K7@tHCnDD$cjc>!RmKZKCg{ikSVr8dwsg7H*gEQ0^=8$F~TO zBuKlzup{H#9T1HY4xPH-!ySH0wpHraUWdV~$kPs2a9GtozpMj9CN*@N)7eg-1}OKb^DwGTou;u!aM&9*`%#evz88Lf4i@sU2#c>1xXI%+TAAhI-`jAf2PhL55>6 z$#Rg4hnrcK2ShDBsQ709X|B%T{9?QU&XpjSu)8o#!gAUj0SkyW%#&kW8vr$U(rqaX z&V^P;@kNdQq%LG}F`#H_!1;gLAF%}EymReGU!*P1?^@3IZ(sHgA8I>c9TFX4!=McX zB5CEwd&G@}C~M(;n9-|#XsrDiG2L}_hgxjFNIkU*7JR?aJ9GFW_H4%>kB+0J0ro@? zQo!73{t$f>6_U4(oZ-IojPP^#?zXlc;%VSee4pUC^La!FMq-!-Ul!~ozpk3VZSv0; z!qS-Q$BNj>NY+z!Xc%k+`rM*oXID6O=0pS+Yw7PBcWD_ZbGVr3eH5~KgUZ}kq>UD* z`UkX5P%7jWVi_#t`s&Ixr0nGncHo#`m6z_QQ3lOCkd6ov?&IwPu{nWJ@nGK^oBVh`@#!tAuHmbZn+LX!4`S^Z zOC4*~i>uYIYW3##>0TLUjwLlIWmL+Yj~u2H_kEP0cDmYy=l!%G^f1`(%$E*EF>O2< zyOlxUzVB!z{sVg)Rm7{ObvjT#i%k#zMP}su5k9+)LUqdAu;nbXXCu#Z@_QIeC|-3e z5Ax^(oYs%bq%=!L!gos9yrb~6yH}cFe&ym+9?wPQ&d*t~oJ07usNTVo_2w|iX^t%3 zx(7k9514`aO~L)(o~=QiC)*9=#fTGn5=L-_V|T^E0h7xxI?ApG%w>wEMTmrW^yj?=g; zfXU+3OYpxDt2`6Qh9Q0F@CU+MArgIkxSNx{3N{PGCt_S77fF)m0WI(eulx!KlQ z%**_Pxb|!nPy(~m_~N-BRU_(Q!|~M_gpc!D2NPs0+A<(v)nb|lQ%0`;MTKN9k-U!7 zM-ci`+zY9PkP+JN0p20x8upqo{<$!PRfhy7dbA>EbI%=gTjc7v?NB5O;w_qLSVNu7 zT|Lsk^_2-xlTqvus*@4M>YebeeVvQaFY}S+7jo8|V|%Mb9s%{a&3n07yH^MhV4((? zQ(-oU(iWYr_aZ*M_Z)nwI?zE_VE7P>R(F~i^8I)Zeg6CzW!~HHu7Q;1aI6k6_X=~Q zlaR~wz4%F^Gl~@i?zYO+r6IzWW#kHc7d)|vr&rCRd*&zdIR^Zjw$puB9wJKT(Q|U| zRi`TfSb$bCNE16xth3)T+o7R&4$;eYBW-==8JQX~8=)r;PX2>GbQS4)0Z2~=6{SpS zD6qUB%g)ZfOjiHM@%(IPt4SDx3V4^0Em@yw-0#MbWtoQ`V>(g9Wp)G z*w}iR#G(thLDrY8XzZQSLy4**cjj3c4rq2{HGYAmn3b+w7S52rKT_Rso;_RTxa?tE z96mBY%>M`zdxGfn7}KtkF`9qq^mag}zo?DfMW+X628yp@{2tyC4C1=H?fV&M4%A&W+4b!dX#;BP6D=n03MG$^CSteV3tZY0+$H%eds(zATq~S8ZyEfq;>%LW zd6)L`*qfh;^<|u5Pg!QQVE+e2h(N&BIT-5Lz10GWf}6c?{FHSOVL> zn*%h|n@sr}ax9PBXLxF(jmar^!A!iS8r~B2#RV_WI}v}^-dp=Ynsf@XEH_d!r^yj2 zwzMl(4OzE~2vG#FFdK(y5SCFskF#IGQC`OyuG4Q_Ry>Oa>LMonA_`X{c>xzpZQZP$ zOnB#>;Z5wW?&10&aQKitbqjwXE#4KNCVJO?Xnn1?U{{eJa%OLEa0^$iPli_hu&>Kk zv0-GzV6FP5$FZDGR~faTCF^wUD{HuVN9A|(=M%RaDT-OFc8)>$NWpyjyZOTp-_4tq z4|F>Shb~SJhAV9H>5H;E1HRqde*37%4&rLDWVwvtI)##ipqLdUVxwy}qos?%o6+KC z2NTAjI;AW%CDa=s$1$*BhZSs~!DeAyJp1fXd)y=H+M&|RI~Fdgs)M*W=aiB;-mKtPcJzJbOiVVD3=l_IEggTbCJ}0}IU7pDtI|K&_$I z5-%LrhvKTB|D>2&g`kjw!xPt0^|J)*N>*3~Ii{%&D-)m@f8?Qg1yD=bim9x@dKeCR zSh4f1`6?t!1){!HE#L4Qpi+M3Ic=S(@T!YJnHg}XmnRi(>(#$Ocxkd$rJLCWI_{g= zkx*Yv_d!lOD@;i2tH%{xAyT8)3aXu9pcs*G52&yJiX_> z19b>b5ALGuPk-5GxkjJqkatd3SNC@nwlpmEidiJeNwbC%xE?AwM!@zKfXJYlu9$Wn z{oVbsy)rEI#E~`-r6Teu%uEn zc9hxQ$~UuMN~LB(-~pXU&%=(>_&gTR)D&`S)Qgoe^yM{~Vo4=Sl}vPc7SD@zq>tOyW%=Y2B%5L=di^uy?%w{HUC@1_ z;&zC4U9oKb(}F#A1Z+b=w_p*%t8U?;~7MZSJQD!SXiHt17TfoW1oOd3q&# zK0j>wZqu$7!`@+kdYGBgPBJg_MUuIsLq=v2NEX_3p=BPAZBG-O#I57q`_+WNvR|>aXO- zaT{3Uv;(RKK{fHn?72FHfqGRB#24^yh zw65uV-mAFs^Rr>;lO1v*JqoiM$-nvU9&3-R(_x9}RVXn3_@aCSf8D6n-C}C1Q1}mz z34rJK0Mhzmb1la%>g`r7yLrOULRf*0cSFsOPS5OiWJVxe-1M3l0ou;*|)oA9~ph90zxPyf>iEwKfLu!6)M_4y1GME>=Tekkef9gnOq3 zZUK5!E+QVs<6{(&qCHVjXOQNM*lEM;!IzJpmD&Ee1nXcWRH`dG<=Y9dL-12cd7CE( z!rzqXP#H{;8{ewH3>hu@R;Zh-gVlLvIC%G_UN>q+N;ixI(E9?U^71D}l(MQAok$Ohk!!wW6$64SkMPZ_s`E$NV9KC6#stIWaSb0oJe&F)cg=t z9LSlnH-pAv`mm|a-7klU26mD01S23EPc^wT{Nr`1VW62yEuGu52d-Z=9C9AD&jVGy zntPs$&A<|ie~9$iorg{k1%*j{YfKqCvHO`kp^ihU4M zIAp%-m#y~zes2+@BINM!+@yLV_3`hn!U=pFBO{}!E1z`z_DkUsF%xot19*ox|Y@a0_w>i_)T^}OLY`2YL@p4xTt7Of;ZpZCAM$p_Cp z>|}#`q8a|DM==Z@a3;xearN!j?%&Vqi34z?{eOM|(GPJl>3$CcLAwU0nl$w1rydb) zd;Pp_-hqk?#bg6@O**lD==w`3K;|(iV8wMde<-R(c$Iyp_C)y!v-_AFvb!|ZSI(Es zE9-RfB)8+GL!lyi-UgLbc~klMLzAo4ANTM5HDsP-{sXh`K;~iG;1NE~Q(TEb!LVu; z_jO>+)!$(a&(`asw;+GD6#vAPZfgCN9{EaP5v+hOJFy#9R%$ps%&&-&|cUXA^huE@EkfXn;eaF;SgZ1h@iwgBbaq5W^1TITEaO53`LM~bu6iq zP%5BQ#Q#pzARn9<#Y*{sU2K6g_o@jRbNw|rW>5n(0^0B}jfA~syB^ImASL4M9DiQ^ zf)Vy?GEh!~r;&FPO%LAdaPUA0gDy@N>$VVq)(tjk_pjQ@R)}RH*o0x9mlYZttHxSb zCT{8BGJbaC)FnZfpAiR6!8MWSFx{7Hj^(}m6;?1Vf|V}KcCuRuw(2P49!u@KMV=o} zyh@SVL;bC@f`-nMqp3FROMTdLXS~);+C5+si-96hF5Pz=X!p=;=u@~OgE;I*a#trP zza;o=`*9H6nP7ZG>dWm{A}7W>a}wZ!*Vb}h8s={l`LF9Sy5QKMM)rqwVA#QB{kZs>cScA^!RjQ7x0pWN1 zOU63muw#odaweU@G1XhTF85>*Mx9ZFEe@BA-IF9jhd?&_EaU`hvAv}m1|YCZS%ScF z4hWyA3UUV5zYh%|sQKE!j;sL%!(NzD(~Qz6ZVmpZ;tqxw(d*nzgkpG_g)(sB%f7WL zHHbHe?sb{E_TQ=}zw{h;cZ*t3IMwzz7;d4B1ZxSn&v#%t*6?ALB}%GWrOkV0f;V%q zXQbho8~Ay#X4ZT}Xis|vF1N5(hhZVpv{N>WN%cJPX29-2X;SmHY!ZGe8AL*sX&@~% zIWY6T*CwHesWk0kyM5d7>dwl>O9TVu?@jnOs>S-mB|b8b5~%RNoTz%=XsK1`kA_!Z})j zjY!w|$9~HJ5WtUg06V;|2-djqyZTzM-&M02w1tW^ytZ^DNwm#_-gbzC^)>fd0z*+v zT1pH9x%ZL}W471nI2aUYJU&jzLLY4#1o6%Ia%11BYrtycE0N7TjNi?j|B}pK?xNt` z?UyM)5FOxO$u=3B;#>cCKynjq|J4m+L+5{*O9rt#onRWG3;G6Xvr$*;3Q*FZ>wK(1Iw!MQegIq{JO+m5w12i|dcVx6B-}EY zsb8AjAG$cZW-ar47up{2e7Di|S7}xe%4j#*K7JD%4~8@rMCn9R{mnT}cEz$WLTP>| zG!(JFA!}MPmbXi7=vR@(}W@tlsKvEuqSj{N-IXC&9?}HoPwGkN|7sgg2 z5h7ml4df>L7lo&|tk`Vz;g&Vgc1C1kCL}90zA*>d}Nht3+gRio;*A}1nQlDYVum54NP;n zVB8H@oY5G+KzZ2ox*;P{?%t&rUJDVXAmTutc}9HRDX{DgM0ug5ZZU#&4z&W_Q!u$; zM(#!si)5f@qlM`O5Qj=`riblN7WQ~v+O&Dx6?&ks^$G=f33gv+t-A3Qz(zfjE=6D- zzk$g1^1lv^vRO#uljxFaU&z#TE`Dxn)v4+0;4ItfbNkX>$s`J~fgI=l`%hDa`)qPw z1SOd(R2`x+H7X^IstaMVGpQnRVtuKs5uCJsnAkM^U*(Hq-ugpdbt&(>rp7t@Z4HEQ$9PSY!lM#00Rri0JLNqc^D)maUgo7%S?3%#j)hL0`Kr8-r; z?3!iRw35zeY4%mvF`pt)HNoZt0A^HLS)D27iZ|@Yi~tzO4S+qC;6sDFLjfHlj5Svy z9!J9ftQh^kaO_*0IqBN29XZO-GNhT6_C7gL0Vd9kP)D%9HtxusIpV87XJ57m9C-rV zMC3w+4!fuF)2`f@J-E1hLDzS*k2-jYxycY%kv_`uhL@OG$*kNK3t)d!LIv2fnv2io zn|zT0aBndEGs3Nh4vN(_)X{aX-I~}c-7fGU7jq^!u1w)|zA4;P34|N?*MC6v9b6}& z|EGW56bhfSDFb8%QRgM{gZsY4>7BtU>rItN%16l=|Ti{xUQzLOy~T@r{I^2OGA z(8WKFvg=S~gL#!jMsS(QFIey!6h17$_!9x>apq%p272GeYXAbVIn9f~-+G@xsyN>n z2><7La;xyYfRB6c|l&OwUTR=%bJCoHrcxqTB z14tEk&i%T><`l0;oHC|7P+ zpgqzph=4HDI@AKGP53`5>!UWz?Y`IHj-QUH1I4q=EVRJC*wGr(Ed)HZ6w)U9j1Cxk z_kzOA*Lq2d{nj-sK1(zmm~)P=ypO>=s3&p~rtnR1c@3=1!=opMk@aHoed@8AR{Klo z$I#>??~o1ODb7k*eQ3p4XbXgZcdn`Yz#2P*TGR7FHY_sL#mQ9d;uu}Gcy?x6nc2jV z3%{nWpF5Fa_2&Wf3^{klnvDe(d>RE7HIQW_`r$8dR9WUN1$k-#11dUF^ycrXQFx<+ zVL3|vtYiP)#dUagh&pManbX9>m;XueWhtrAThe>?YUJt{GT%V0@t3q)!-YMbxqyI> z|9~OLp|0;9O&0k3UE63*INlPmbB~7PcBk_qm5@^rU;l)ZS&atjI{9}e{~vqr85PyG zb&D!0m=Hw;5eb5bNRl8ZK|m!IpddLaIiut(h+!i+2qF?B8AuX~R3?JtB$N^glq4BJ za=c?AAp85yciV06w%6XDcYe6{=18qoYt1?47=857JK?9?J`c!(qdkKC1V}X!GCEjT zS*tCp(iR04Owr`N!>9)3i1@7g{%am_)FRvD|CKwPh0OB^HtB$a!Px+hw)G_DOwmrF zdV^+q4j_QrEK+&^S070VR4J=!GYLA5g*~EmIrGAY01O=~7DuQe`#$aC+&`#M)_zTF zHn@3u8rnq;8PIPPG=9U8f>)0kQkz2UM8XS+)b$kjYTe2_JDR&Wes15R9Qs!Z|YNJ$4xN>}&AQS7``Hdq>=IW47BH3f(9&ggNMJKVdM z6i>)7-FWzclc^-PqDX^(d!EHjtwbB~BmcZzzI=$>nUk9DNs`(sS3%;1_8Zvy@g7ynL= zJ$x>y9U%30s~FnoZl+OUo4&GdrTf zTQHhSUt-L{Q3!_4XU8HcRXz7_Io9r@W37SUZL@KyQwNdBdoC$t^bg^YoW5oY6pfKO zm~9DV*1gu$)P&R!f#A#CgRH`e#tJcZ)+NaIqk*ZZ@Mkypi}cSZ)TDg@B2^zuuG_1F z7MkH021O2MV7~Pt_4{g$qMFeguXly{a@6COkD(r@5cigqtAtkWx8T<+C*#(2QM|Y% z=x=|Iu-VQxq{Q3<{&E=7A4tC%fQ^ut7!D9jx>4FkfS{l5cQ$EJnnOLS6vq$lCWG@3 zH2x8wI{&(4GC*+b_UD4tB{*lp0mJ>L3Q;b6@4iJ@@F`UMAv&;OTS2v?t?c&*6O*z$ zHZNVh>g;ch`+iwe+KxL}ovF)ofcjhsT^D!g!JVr7sx4a}|L%ykeU)RKHmba!uHEq^n(<&EHpDHa2 z&qaA{OvUw|51bz*0-qF>1@QW*61_h{Ab1&b2Y@rmK+)%d`*c4vXrc<4^%I+cXA~4V z0plwnP_JpD36{cU_1SPCm|#hQ47*MNjS)!}K0wZF8hI8h36g=~NX}tV_x*E|uu(Kd z3ZBahRY8BL4T?{;IrRRB`TTFTUzF(gNWV(LNN*#E(N7Gsf}p$)Zdx_XX@%Tuc~{%7 zl;b>9q&-g}{fM226{#j-Vn9TKAhfmpe2QcCHZosd%)Kya zC0bCo&H6R_ImK)%Lak-kl+Sq8+r2;q`C0^Qtr#D%UhZLFkV_(VrQVn@5P*Mo2M9b=a7()(tu&x2G>GVof*wz zgZ+FT@@)apMrR&}z}p`cr-X%Qnwey_6mfren;|d$o&K^#ppnEiWt9LLRzY0ue3#pA ze-neBQW;I<4j_kCg*y5GcNDLIoh@&eP~wM1zCm#Im_qbA*s0xmwo@V<`Y{PN%)^H6 zHEn3g*Y2ago*+RA)XuexQ|YgJ5#6qhpMmAk)2k8b;A)TXNRY93dRz9a z(kpsUkqAL^3c&d_fMa*|C>qG{WV=D2r?GQQpJzl)E@o{-#u_5QTr)O^F z``B}~=1W8b%u3`zVUS;7eN1+1hazTadz9*=XbfBS@Ia7^Z$~v`Ytxa8p zhV#x{+u7;_X&Xd^V(>;^_JEu`?fG3&$TXB_w~_JZ z&mxCK`xGgCKv@CuQDpp734V-PL!@Wkd;&EP!AG~83Qk^mSb>xu`8#fVFLkXh#}*ZX ze;|UdL6cLtzb={;qyQY|*OcTRbNqA$`@*EHLh?X@L43q`HeRGz)GYQ&zV?QYa#AXLNsI)IJFdOluR3IVYygtX}AHRUhL6ikArpMs+Ng8s78v_xlBZ4Df(iMBu#k%VqZeYc)O`dSK>Dq%zExShX^ryPz z5kLM;r-4EMe0H!@9y3^d9o$&=A^xH(2OG*vJ_^rvgl2?sFhe&Gi_FBVB7XrY`6NRY zrWBCU5S+7&J%|zy=0AYE7Si>A2Fv#oeprHo<|`M1qrweH6fQfjcu7=SLG8k(o^=%% z=1fqyIAUIaO7ZYLTyu(QJ?C*3YY-d~TmhQ)73WMw45=Yvub zYiJ9wAR9N_$Gzea@CE9fIUfYygA`YMh7?Pud^t=`TuKaI(OUv*BIILp#Tm(E`X8WS znV*M7=ADQHcDlis-pxn`D(*IB%l!_qN3}DOq!YM1-vJ6 zzQ~6{wOTOnrOdS9qxYlu;n&%`&y=1|?hZ#g7fJtsrMGkc#c!kZ5c!wQC*amOdgcF{3@`Kprw{&B_rPyo z-KERG*ZMJ|22&<<(G^H_s1vCUS$ACmy@lL?PCx(S=WY7HXyYu1YR@afSz zB;_j9&nip2`4&7y6_MNRW2?)Xu(3=edREYLt}!d`rd45r-h!Z`FU;)Brhhc&t$oT> zg8R7w@~$UZ3v6HJYzSg1{YOE}p!l77w|1b9S^Ae291ybLv6+SSp?NUql*ji^z&EKz zs!hj73lhJ8Pe)moFPEb>m@EaM7(*tq^!@omAZyq~6giG_pb;}o4}P$aPr)~fDK?m%7e&j z#-RB7y?samliT9kIizh2McJGC% zH@@0hI|_xsO@-TWOluy@(~tsK{NkA*S|1ox%$yqU>biKS#QIgAy-_#h2ik(qKqkR` zVd_rba#G6=?!%{LgTUXq8pMv0;E{2>^%JY^&OdO2fJ<3FkTib2C&g2=pc&YgR;amnKGOe+Z$bmZFc&DV-~LWjE(j-B z4Mu9eVAlN)D{k#---0ht?U)Ikc5VFioM*Wqtv2`n{7}2lGes5*C|eci&kmS1eK>>q zYCOm62kjv3>8L4>Z;9KvxiLJ!bQpp{F6~R)M(7JOBr;;wx4Ho2lsod0=Z{HeGTmG_ z+S<;@Z6bqRdiL^(#@)t!Z}I^W3jE1w7VeOe`lRTAHe?<)**T!pUZwdBh>uy>OV3P# zm2v*rd*p!E7+Y{M)QaK#s_5-`J68S@M_==tJ}Psn@j{FrR_j@q6ZEt&>DRgqC7)>; zxRpPP8COBiX$ULC$Ua?u8IKTR-(qo$K~hF zn!7^wp1#?UoBo~J_J5+09F!tOyt0H?Q>MQU;a@Zj89yMBkYB0x{~wa=pyNBMibJ^= z8B5w-SlGK$0KPip>I@)B6X@i=R!#{YSP_K|BGRtx_kA9%>9nO+5BdXFva+&}zevH> zy-go<6&^w8tL*DX*N-5t{KNm4evinj_oaabKv7IUhuX-up*_;yAU*y zd+|8r*>XVL@~nv#R6Tty<^%_WH!*1l%uZt9pS_?p6e`?lGzSs|VG!u6UMVCX-QjtI ziY+ySlpBOcpDsz0&2m8RAkwiZy`E%{)1Vu79x}Fa9OaBL@Xf%qculC%_6*EawH@vH z?@E=dcGUm96Q6GC4t;oJj;Q|5tzgdO{!U-#xY>pbI$ zWqQK~MQg=N1(xG4-ad&^zI0EftS$?0If~U|)gamhD+4XD?zaQ9u;p;o0IZ zvtPxmthN>%KF9LL_D9e7+XFlHpJUkjc>9st9X)5_C3Yztxmv2VllSbv{>P+y(lqR|?nnN3oS`P~~-0U-A1{Gh%%MSSf1lb9AKs8EVLf%29i z@f|SoaUES?aWrn1(0&o!1d>#0Ad7ekV)Uy+CHrBGst0BkGt18fIY(y*drsdG_wg?^r`< zXK;;YTmWS()l}fzXUb0mTv{-3yelDWt1$bMeL@f6@gdHp#m8t5tzUclPP4rYfvJOH zB3a2K8SB@l|Gg{3;hF@E)|$W$;I#3oH-`;a(L~PSIk+N$sZ5~35qRM*wZ*IL@ao1b z6(cSB$zT_o#g?&&}GO+RzohKexPX+ivy*Y4MsC`QTpCX1S;ps8uXg`9UhC#XY3OgG+$m+a9vI;!`E1Cr&l zw^%nW8GtVN9LO4ip$a6&hBXTvM$bZGkYa7y_Ei@u5{`Pvi&&pw9h2i7AL`ka%rFZq z=5Xi;jAy=^>U7~G%n*M4d6wvlz18;g=*?p_HKRK?#c1yP>-gF-+`f}ro1Vqn543qg zL43VDZ5XI7L{W20Dj$)?s>(1^b0-?+pXJ31AAvC+2242HAsoLxD$oi6Xsrf+_tP4y zYCB%-qvWs~)FbpXR)59jAzj{SLzy&Z5@y0?!qJxJdPn?86?q?N1iG(#H>M(~%yZ5X=A9_#E z!*O&h{H=bJ5+7*a^4&^Kk`VrAjr@zWKLC4w)Fq#uT{p-;Bx9bZQ9&$6hMV8AVSz~uT9$$Y36_>mcWtWgA!9NFTNKah2?g1;X zC^G6vH#Dq{2=BN~(hhdJY^qv!#5Eqk*}(?&UldCCYOM(1mZ}aGx~vfgP8^^WQcH)) z43o&c`}Xy{ctlH1V=)3`43wole*Ab%03I|lupYe?!HrW3SL-dBT|iY%{=kLr+ai+V z<1d5!0}KBLXBeZCf+ne1=`W)7_kU3IX4JN2sfFNT%8f|CzjKm-=d>zQ!dU^^DHs_+ zySR}sXyH`22vzM3G+W(zcXc@d5V>B+W9@jP??UfmP5XF}GY?Fb4hRg;UmQ-&?zomp zwKzrGyFh1%OVl7?oNY7bM5a&3T)YXqP@am%A|VXdaU3$_f+sK63mR?VDd~k(ZZpaJ zMRukJ{Z?HVgN84&@enCY)slM0m30oTI#zRO;_ibI#!91Cj*9`f2P4X`@DW zu_UO$dC?SZ&E#B#SRn%0 zCNaGxS^vrXwQwhpxyRe@mPptuk&VCD-~{*N&(FBIEh?^hCGP*DX$X-A?fTRBVD6Kf z0oCBuw&jZbmd74H*ti(4B8sFS$K`**rs8B1jI2k(IF%cW!}_xkS8jy*?K?l6fkoz* zCLg{2BYMhRA@BlXW5>(WUP{p{z5%NQr>sKmb2;~0g_r} zZ)K=?;#W$Dxk(7z%7QzphVz*7Xs~(=>cK##Hw$@PBkj3npIf^&9=oUK?zPxMwk7tw zVZc}^CEdxr05k+NdwZBh>4bpYt9?3PfHji)RVNFq;~KY%<92FAOitRMyc?;WOoHnQ zO07FwmXnNyX_2rcwEC-%gg;E$-EK#oXG}~?Szt7XJc|b8YM0px$+TsD)=V6J$8J%R>QrB^)hfHSbA|>X)gJr&b!Ys8yFC6m z<(i{xN!WbxNrFb?PeDiX(8W3@*PAiN3h5s5vHED)$6-uO2WLWld5ZSz%1Wo3|48DW zkw|%HkYrJ3|G095anJ<5*ge6LET`AnCpWzPxb?El-mEd$>!bHliy$do3Cg82@h58{ zxYVe3q*>A?M!gtedPH+Nm}dk9m;G~`Ll>AZKav6IXeK>32ELFsP14Nbwl3gcZ2*wD z)xk#Uo6Y_%Dm@UUN18yQ#T^a@)TeXI$iU6C4QxwDZxAcnlB3=>o!eZ`^g#sJJlRFd zXN1>ZSE4+9n71EDaC;txgTR*H&M2&dbJK`Gv!0qq-Gvqg5(vvw(>3BW7pB6UxH^-* z=^W=gjZcWlJU09ixP>Y<_`e%yCp~O~ZFX5#*AwE@gHlz(RI9%piSWM464+A>4pv{| za$d{ifcu3rKyOLXz0KXw*3XdO+eg86I|Ib)l=X4%8l=q&0i#0J1;Fsc+_{rkbYuZU zTJD76z?iE*UKTJNLLs;f71!~p{5`7wX$Z7noMLolUjLCO$L>>xL3e=7JFPMBA)do! z__ORW1Ee5oh~y-=2cdGqpTTzeDntMSs6}lY_X}@c%;7ILKFW;Ocb|PYIaYa$^`ZTG z-0>96+Xq)rd(WdI_$AQqWeMMUJ|Go=(i1T5p~@~6IQB9WjBa5z)Y}Cnq!nuH2DOk# zY2gS55K$$Jhiij-NT2iBk@(AXH7ti5*f_B(93WOrXa;#0xY6aUZ*9?jbpI1i^#fjG zB@inI@zfWt8rh~^VztF>XX*F_Cu~7$94rAxDj}!a(a(!{j^cQZm@_xDzOgnvDnkO~JqI8j!Rq3Q3;1BPx=Ww>x7aV13!u#63|B%ZO9pn-esah=&xq{0niH zr-7>|>~Bw2200s=vG2f9AxGGzTq}Zx-NPZ6_eGSRV!~BAN)D}on9qfu^?46nTi^Py z42B?oQlHiNF|C)medEjSzE02X?Wgfxc;?%cOP~?@;f;ol4hQnSiuxsP{3IY3F+$UT zeI7Av;Q{wz0JVWl=^;s^NuJQ&Q-&shlYTc{^>(_q5vb`KZo&g3O=dW`5rz&v6>TG? zLarJxy`_BfXf*VhQZ;g`v`=fhexl5{j$G9sJ+1YeTnOVvI0j;PUyL6TEUt+k zW-3S1wn9@4SJpHNAo0>m#(b8s{uHr~BR{Tl1TsHub^EuOfE|p{hKP%cuK}5Whz_{m z-WyDd*IN{KKW0Tee47v3N%DD*9z8nQeB)XrunzhlXA1_DM49Zz#f|Um`2j5>*aAEo zTfk~}uegzHdX@^}qLEwo3)aSinKL>79LcVjd-(fvva*J{&$LR~eWDaDc?(GX9Y-!u z|7qkojF_^H6g{BE#RTwgFoqMC)Ag_TcE(tnZ zQPdc%sk+K#pC?J|xyBy-NOF4E&y#UmvpWXOlSDf9;;0Ao{={8tqq$gRUkW}4 z^Wv5|MHgSqQ(b9KoZs1E`3lDUuq6FHFBWf^WQSB-wJQt@54YmnLna1$@QAdt$fP&f z62EQlW7;m?AGo2({oI|;u+}-FUMStb+i3Qr>B@`Y2^tP5+Clc%a6%JZ*36Ey3CK4B zFZvec)_U9VgyDMMmYm|A6BpqQxCI*+8t^o>K?;KY=xpJ^43~1Yj2j+%DQofO^apzV z^M2)*ay@RV_fTtg4h?$OXjUIGSP|>DT#irgeCQFwQ!yYZcPdKoQF$GWx1Ya=!7&4i zV)ZwS1Cq(|7uo_1f-)y6#AYOOa?t(xN61*+&*h{rAVAtCL#*y--pamw(p;qBI9+1V z^waEo&2}Lv{wu9i2C{y}7p4t9*CA178cwHhXQJRe=5CzfVq~?=$F1joWtJ`wGd47{ z$P>r*RpZH1AZs@M&1LSKp{dtG7z!_=Nlgyd=ddv+`x267<+5F}SEPY0k`aP5SQ*Wy zzIBt()O*5I;WX^%d;SvXBXOL%S^y=%6rNbWTjBQvsGoNZfSep^g|$d+#e|*_MRo)xlBs8>Ic=W8rH7@e`kGuG zc0(b1XgUS#SbtDzoSRG;?RJ*1%Q=9rF%vI3*dk=`xusjJKtRxd0d3$_^-|_SaM=Uhj zR8rL%Kj6yb9gxok=~o?BOFK^m)qyxvxR0DJ1uxypBRBDzH*k{j&kY38g(=akKN`0^ za3*I0I)x&&BiUt|n{JSqb*SHiE^Z)&4CjkZODrKiYu=9E<1y}!e zI>-kS>!1_S8*dI$&`HRwdrHhC(4fhf#Yfm+1*CG%c!NPN{&U)wqb^i_D-jyZ^7F&G zs~S8P$}+&F$*KU3ln&gn=@>s3>vB+UE)Z)(lKZsKgd0!~5SYNgB!qsr6thiZr8Y2= z0taTG2U_H#W0yrq^EMSX@Sz{{c1C7aA@^+Q)atRH&c2jleFn~v?ruGnXr7k~L*>GO zbBj88785i>L}?SjJPPid;c8uxq)I(KOrak{hOy z$>tf3y_4Lsw%w*ie2mUQBh?UVVC9q%b~WAVpN=o$P|&k8(ewEbMYaXjgkMVJ+qGQb zazF|r(){aUkBq*5#!DGfb<26HMj{kaJHrc6{rOg`-@3LP%K5;n@-p1);aQq`xea(H z-)IP6m72p=m0brhFpbc|DLr3p2FoG>v^=^xTBYK~&Yh4@8?G$PulZx5kzxTc-((-Gh0hY@l(EQSB(+q_s+lfuxj71S zRMH|mhwH)CZ(N}08$i-UT1C*3RMgSYF?PM8;!NK43XPYU6{}tv?9x8LJM2U{=Y5kJtQb#qsEXJ>xl{|PyR~hNI()m7o z+^&|ZV#V*gHHN^RxuWsLNGWI@%CAz*#o*_aYQHa-4X_To*0x!NFOgsjG55I zY*tRBN2RLhdw?7?!Xd_S_Uz;SMI;mUZmxS@VDVgM)&OFsa71a5YC?nuK&@kgpzb}A_V|)KB+0_FO;`DU{k`^wJqCL45KfUW zZ)A)OPUvX#0NUaC(=LsziMpWF>*&yfMhNKKetC^Ru4hVTAOxzY&nuwwc?Pw`NNH%c zAqcxya;(oD((-&u0cSaJKUXE7He5jxy%GIzjHPDABF;DNU0=8+a*b`Y6u2=BLzLp< zpPc~Czxws}-~0L>pUFs*i^NUfStmFnU9+*K0Yv(=hfo^i3;&d)(*++&2DY5M)oX1($D3jN{5;Hi_2 zIb`)4)sEN0z36a&F^6tBn|P}C;N}b%&rc(RnZ5DG`ypWo?5xfa;`6CdOfps)qIv+} zs7)vi#u&t=vc`biSw!^`=o-h^tJRI8`<~W@S zPdCg=hQUIwi{ACh+gIQRCW_C}O!gHL3G$8k@haj_Hm_D(Vqz&12L$vVV!1aYICN)m zYRto#<+_dP1iNmR7xMG0yl0)f#`Z<8wkxFI zymd7fWzP(~|C`6TBmPyXDtaZZNht50jJ?jn&YqTgLx-3N@=^QA7vqu|3ojnlQ+^BD94d(g!02!neHQ-!JPM|$HV-&wP=3*KSlG0@5>!W=3qu42q6dvFEFfJ zaI1OBB}jfXW^nm{{jcC%n_F_*_G_LViOe4vH+BV}B;{Wu^ORS>EG!%pQ1dYl?EQ5# zQbjuBXL*}>j;$`fUv1Yc(`VfE=2S-TA;FNJ>UEQuOC(N`hn7TmmLfjhO_qL~6Gk9U zDZ`d)?!E`rg5`W*3={qt{O1TfEv$X!w_Wpe2IjaO0bNQzR*PQ5$0xm>VwY{q`wmpW zg)XsPHi_^W)K6TPIXBJ79y=bcW|Sr?WGKJ(R9wOux9>c8{w1VsmPiy%WNu3(}P zm+Y)LWL4X);fQHX8$Yj4z&>7GQA#?gQdzr0L2fBRxW36~it+F}{Scc1yLc@-hA`+S zta&&@PWT6%$Tl*=B$rr!ZexA*0s3TN+}0x>TdRR`JmvM4Mf==D0Nip=-Oc z)F!d->iwfL)8w8lrWfQ6Q1Ep<=%jthyfS`F4RBd2&JKr{D=O#6=)b-nrqR^6dcb4( zme?N*3cchwE+{Fe6T~$(b^`qcufzk~AVjMQc7}UkY9tgg95gr8ZZ2+CZ zK#qDZh5h2+J!Oyx0I{D1#ooz^YQ$*Ru7|LEa*+r1Ha3UBA(lZBdv~O z8Fs$p*Q6dF`RB`EMS%fzF*FOMc>Ov(QXK`fXx(h%zvZ{NbJskp;c17DSvwucb^crIbH$Zw`fLmB`Q?qA7CIV!A019}Gg6-)i zA^CrkEjL{Aw=%o_eyuGvO-Ey~{(8JDpg5#Yh@r}dY`hZHgjAjY?77i+Lo(?M1kj<> zYb}pa)?(#}blb)hN0Fg)DxY;H8>56@ka4WL3Nq9Kro6-OHKB>C1u(%%hzw|{o`P>& zSY4b^ICS>nXwDa;rP;z?^b?>Y-?5i<5?gS?W>!&=5Dx2qI0PmlZ3!~O|Svvt~MH1>rzSF859TBwC5eN8&lA3eZeExtY$t;ll z1KO14*8S%dRH*lA({_dw-6ciQ7d3@yK z6O6sGHuK8@Xj8y9<)HoAyk{-?d7WE3v7l7BQ$P7dS_O&ow*sE{#b_iNmk%YDfHQ|q z@r6Pwc(qn~PX``rTzdyC*0olAA{Pw0+NGtn>kKU)PvL~p;|Wxp^`t_{wl{NKsGK&} zCiHGlL!kBcb47{uZ!N_Z6jq%yP5a12<$JVaClwD!2Tue}LK#*8?f*lzu4h#aP;r-z zIsozo8s1eV6u?srMuXW4E-07ynHy`-mLu&k;j8=4C0GzDAP0Y)(Wi4=)F$&aCiyJVf7|%IveM zipt{yypVorL*h+|@>fA>1>*=EX!|1GJlf(&^%Veyw+{IOtr1{nUl`@3VVvv%HCzIZ z!dGc6^D>Ah>?E{KaRb|ub!5`&mpM8On{E5U76=b|l`SS|ScAn{MI)pHjYjoYilDKd zIzcf-{~K?^oNs@nux_@{QU02( z)!t^%Zuw(LsY}CdUsUX*N?_&T)hjod9-ZiR5EscRG^zawM37A6@d&e@YmH_qf0Tg^ zd`pTEe>DfXF3E^QV+zfWvqQoRAgXcp?!uQ?10b{5^T>%nUujIo$7OIvUU0Pc6!09B zZk(EwAvu;?6Uuzefky!i%aPLFY>pn>*8lyZH^HOJWte@of^%aRZu2UHGs65C;?R6# zUH@qXtyf}4s6t9R*aHIqC)x2KKBjCAj0rxc4TIY22+VYN{R1f&q>(F1x>F+kE)QQg zn3;DmD{{WXR~OoQD$vonPboCIvhg`g45daWY_0S-0e?>>g$GPqe0hDzjX=AEKA{dT zgk~W`Jc^GTJz5F6)fQ24^Lgv1!m7j-Sm z1CNdo7NEObk^I*%vZk9rSz|Pb{+#ggd6;oHC-z+ z{48v2Qz2W@^y?T;4>%Jw9_{4Z=yYIEYt{*_SOp~{GeWJPz@pbbg9Jk$BPWl}7auSA zmV}{Hl@|`pgy3<$s=%OF92e+b1fix| zdVcrgya;-2UMpqzw(x ztuJ{5e)jdXWHx56svYOJuBg~eULSd*?N^)z%kQ&%EgE0`i3s+&O{5LbdTnP|LDERL z8hg(9PQ?kd|M|H(c6}l@uLZ!}AzZDty-;mLBj;_ORi&nFUb&_WU1VHMmPK3BoKI!( z5(D-mA3NSC+raA9@E7CuSu@kutK-U;R-ZID53|yxv+$>q=nBRpSIaVI7=daxE`dpX zto~jzb$-$M{Qtvz?a>FB0~bq2A9xrnr`=o}Hd^#p6z4dJ-!4v*af07UtLCG4#%_4x z(=n4zIksa+YROZk_aS6|I`j)UOs>d1EMmHgNew4-pUCGrzM%z>Jb^%U_)G0gA@5si z7d$dcsRgA%Vlx(2D19M1Vk|7Mr)xFlGq7Au0#c^0U2d%Erv!AgSx5+cGT$?mK&Os5 zRidMDsjIr$^2we_?{Ye}9Vu8YDG+_#~@!l<*+m_LBX37zB!JfZ$~Sw0IjZ!;uDv57}tRME(Gx z6ytk1+xlK51cZG0beaDx_^2C~*ILhvv;rspCd=8gA@}YP@h?sT-p4*MMtTv3|B=`= z03s4F63e{IFuV5|pQQ$FexRlWjWL0|iJ*=7TQ`(K@<4K(!vBzT6S91|tEm<6aH@Vr z|HAdeTC{Md8yPnudsu}~MvoYRM3ch~i_TI`WKbYeg$#|WQO}9)?o#FmG{pD(ehR*- z0~UAz^B+j|2l45Zu3FK%L&LBd=mL7>=8qg<;2;acjgZ5TI0+Y+g&w{LrX<>(x!#Dn zi6-WdbinN$u@Tx1kqA7P1nOe-$>7LGoxp=?J6SBTQ9oQ}4;n+5+;uuDrH!6M)2Z*S zK@%Bwm(2|rT8t*!%rhzhSm*@{^-$nTFmY#^A8xdKt%ksmdH7sI7`m~6ep*@WL9a+l!Oce@dYTnS?1CvORJ38}~@evna zox0T3T71%sXdfcP;t?UX>BSY2S*^R!1g?$V_1c(S{^_}sg*x>eQ+f!g1DagZ?R$uutpdL}E6nKYKZ|N)xi%+J?Rp!9P z$;{0SUceAkxm@A2TqHYkWr#~mgj=494 zcc;zP6MDK?B_-l3G*}!Y8*B=pYv;txlk_iGy3(w^2>aOKh^0`Y^T97UVu$= zP#?U4r}(i^A{oF!fZ9Muz{8{I!J6xoa_ZvAq1Lc9o@$2tzj&&MF+PeUULnqbx)5ETuDyYRq^NHZ{L}v{9%?}KUviYkkj;k zK*`=FIv({2!9bVzBGRXY&u$!@&`Z7+JJ#fEI#mWucNW)?%tuk*#7WrAIWzUlUT)qv z5S;sWJGl!l1Vw&19kzjPW?i8-CZT!+$pwgWgh3a+66~4d(-eU5fXZ6Ju>co(nS}mP^`*rTMl{QX zeYVNa-R8BQB91KbqyLr0jbtxjm+Y*8tYa-w$n{+rK(x-;IpF734KyYYY-yC1zQnHd z#!ymFR6)n%RGY&8L__6Bo5lPeP!gOg2?@!uK$Js6k~ZINJ*os9S61X@g>=XF?%msq ze5e8ZbYR0i-1Rq9$sU3G_DQ=OxpCiLfwjutm|hhwGst!yIs*E*vgOKa9tJ(J#Uck<)*rvPB5G>iiqK1o*DT<+UnT(aGF$_LeSry?-zYfnzQEkvVjlwfb3e+ zh&)B|MVVX=`@E*!`SA*O9XY6TQlh4xW|m@2Zs~}5KXt~c@ZrcxqFW)WIe{CKiKn26@=p~UE15d&28wYS+6YxJZoi|XsvOkpL1oFzL_fGx?GCYN4o zrOr05`Hk#iV(HN+lxsyOd;3kRGCE>UCD~TcC%oTLI$4=?QKLIecM2bniE^rz^)F*x z%^walwH~?Od9*y?PUbJ$+6r=V;Bo~2&-=ao9)HV!AXZ-P94t_@E2If{4O~yXKTP7R zmaX2Isc@fB-{@9rv{_b2~x@du1ddE0Of$Uv5`+n}zGkQT>nF?_% zQ{>jdJF@fhzumpaMz?@JterNoVoo#P%o~$v?vD+N{&sUG!>sywcD&P#{_>l{6IMRh zLUqA%3IQNYI7{lKOfGc=lr$NX>lNL`tWeDC+E>NE{Jy`aJ?>Nj?Z^%Vfwg<9sLHq)bC||j;KS3faTQmE)*VDFlu}ihv_DsNLirQydN`># zK+GWMZKv3gvvSUgv$t0WVaX+|Mtrv}_3Pbx+q5U~4PDfL6WqbvwqDutQi+SvpNjlW z`^_HNZB8iBuy$vBR4^vy5%hyL(PM71-)Wz2)Qz6+J3lRW>@&wrOC^jk+*&GP$-fqP zo6c6O@X`74b9c*`QfilVji02Fzopv<&Krr&pIujB2HfTB!uc~o8)wc2vuW)p__*&} z#W6SP>qP(Yx3e@n#mZByF6|B61SYrqw&*XS{oKS|#9{xw;(DC+N~O5zF!;aw<>!_E z?_D{FDu`Tu44u@L;VTAHgep5ef6E3 zKpy&KDY|Ft?`(em|4$$G|NoV@zOnzyx7@%s$8i0g6A_SDLIF&XfdZ?CNuPRs5AllL z1d}yIHMKJUC0mM`h5Q#HuBg$ThCta4T8NE&dwCxmI2RrffiX;4)#J^4o}r=KC!&v- z$u>a5g0M(9PHoWodonc%TSgN7K5448!)kk>YM=V0-m9|W*t|#<+jhy+K38!vTk{t7 z2@1Ytc|$RJHlnt`j{Hn5KI<9Xo_;yf=;Ed-qRw=zrD%$uu6%;$Y=8V;D1a5ekva!rl2$O-+jxd%+)OWt<-Y>U*}*t+uy z)LE+!nF|<9<{ewH2@ETIqi%bnm7VSeHQ&>$@W*>Kmy+dHdNiq?xqa{?zY(>_c|Le| zqORy-u)v48D%O?7;OKViDc#${vv2yY$YU&hfMttcTE;%BmHVAL)t5S**Gps9%5F(v zXW>hcN3PaZ`z$&~_r-l&HVHvn*Ndkp4x$+zQ$qY7BjG`le;{g*Dlq&l&Hoh_npq((;CLFvOif;#*nzn?Yt%3Ca{?hxOfo7o$>K>rF*bu>2U=)UYcKQG-59?YUbAzD!qs#N^|Vn@m2=Z-p=&9Cd#K>_{A_EF5}dTfjH@G_dV-up8=Uj86ND!6(N!@DFfI_6|3-WfPN5gKb?g( zb1M5a+&T^@3Sm3d{)F8?gV-hTY(#S%NOox|mjkKvig^jCK0vsz~aN0bqedv}Ox)a+-w^e|U0ovLmUf zP_MhxXDm1Uw+AojzRd$xbPR?!YY2HsM(f`-NX^p~LcyC*J<#*987{`?0m5eebyVR78?2vN&)S0*F; z{nAC8rKP27{0$8ap+>t?W10GwsRpyHfx_ew!yW{Re=d7eZg-LmsfnBG+_{EWXjZ;D5lE z>HJXfkm>j7;MdU_zu%;$;=(sc?%Rl38G!9$^Sj?dA^!QsBOOH_1~LrTDlBXb_j{n{bhwCB?| zzrQQP23E#Nk!`Oo{du`358n*plK0MleiZ7IB0DRfb+(Y68xm#w`~J5PO#nn({^i!j zdC?mW!DBjS2pFk-Fa#40f=Jb3`|Ypnt;>svz` zgR|3n&Q+*kjG)Z6DW@{7fTkB4VD?V{|EmV)CM{OwkqBr6WVHW67JN{G>oe=I|6SJd zXjxYpB`L1Ig9cGABT`w)X9l2y69UZ)e2XItd4NX9TbJQZa0i@+?s!5wj5sm7jupHF zS!OAhS@7(Jx)~5itTqWPv?gawG-Ccn#olmOEgfdp%r?3m{_wyv!&PBf+Q~UP7Ytl; z$9Q;n_{$;A5E5NOGX0LPr$HYD2~cs=LeMiy?Asc4;N{n(%_Q<|-17x^`{vGrt^b7k zXnIVB##W#ngn35smi2BAkG`PuQ#0B6qwn5x`&zk^z?QeOHOzb6hlAtvDTeKv4{CGy zpcr3@mQ2in_WzhLiqgM;m9fW1YRhS`^|vQ>9y8dobhlm>?oH}U5M`uv!P#49$#mXlRAJ!ktte^H9vCL%UHDFL$rxLcBOnel4=K7}!y^BcN2TwR#w=Vr(Cge+VZ1B7_VDd%5xn!Swtc4lyq#$9 z@b2M7cBm8?eG7f>PmCu|xxZ-Y)bQ-7<}_guwO8SR;kBk5lsBDx#a#eOly8{%8S+}{PJ$nv3-^L4Y_r*}ng zYQqunWQslNNow;Za=%%3Ql>uVs=n&tmVMj4L1Gy zalNMjC9!g$%nQ$@<o+|_T7_pW9hr=p@i z&mO*%MsowxsbN}888cm7ci(D%p4)eGJl-Mu$lYoROpY)g_pcwWnWbjQ>>6`*Z6{1> zi{3V<>;zJ>6PDIC+zFE^=cbucmLB)PagR*Eg-@5xjS4yv9g@oM9F~+I-h+90Z8mmd zwrep^8(VibRYGGiIK2KM>Wj6-W|jg0&&Svw`%Qf)djD*Htv%-oMYRvjvxGiZg&xu_ z|E8>@v5%|dy?>M6Ri5o){!o*PI`3 zToYCnW~u-CoU&vgT?{{etZY5?^71=u5`KqC=CT*~hjbIKhI8tU5tD%n zUqHp)qN?bmMW=v>z{>uvsiiTH_olQ8i(F52CkHnbUDWx)^nJb^oaeo!-8_ zg4%_y+w7|GNr{Q4jL2)V^Hn)n{GYQAy87&YzMymVtlOi)50f^U#Z0Sa5#@_#?=xA~ z-d^PT&N^N}m8V|B%e6B8sm+QapIwu!apv2S+v?)mQnU2n`l1gM=AFUq+D|FAM36mh zu5ZgYYj#K55*f_1*OD~!apfDb0twLlvLI_ zB&n=NlOUJvJ?{$?bwx!kjsPo?s|2Z>;D_(NX)m1%6c4R1ae;qYT+ro;HDuTzYaHAiC0aMl$T)l2c#~Ax zbTTN^nA>5j<6?5O{A4`Fz1==4LW0@1QJfR!fyg?M!H8LcRBB*0bRo~Y7RL))VD&;A@%q;SX8!XN1{OkBmi_^1xUc-S-!S~J=JpB@ zQaq2U2ux--W$20&fYM&~xM$km*!luSQmLwpUN-&t(S1dTuf%+=AZgwUz32!K6qb26 z-%lOMZyn>B_5(Sv6oW++X~*HpzhV2+DX5lZ)8w@O1dQDo5J@k!dw`$S9b-P2oBKNI zRh{g9nf&2KAiIykNyYK!Lx1A*L2^_x@!A13wBrZ}#M-pmOPFn~v4kLHj@h|Ek~v6> z{(8f`2u@4-*pFs1cB&}&{ox0lzVutXZ4Fx3n{Y15&u1jNR zQr^7jg&N0Rq%ao5Y%ow?h?7v^#F;%Bqg=M|ujs!+vNCnbX+79`or5##Y21@~&?tl% zKL6sGHglxh=g8LsKn5a@+8|LNq_+FNYWwQ2D%WpYK~TgXL=aKHO(;r83rMMSgVKnk zbT<;qKv7aU6hs;pO0#H-E=g%nI;9)V{M-`1bA0Y|pL?IP|8OtCwbmE!_s%)T9CM6Y z&rZpFAmiGP15ae&)W54oum3r93siBHhicjPak*SN zHX{*XbEa;2O{pzwTwwa967CA#`zQ$UqaFLQ@kz8m;r3yk*69r zow>HQy`tmYrm2zhD<+?_Inq(*EgFZE4|*c}BaCyQs!^+XMlZF5QM1Is=dGWJ)6OlFRh&&lujb?)!*|B=fM5<<3cep;GoKN{_?JVNmZBx=7i zsN7%C2VOn71*_RAhduZ2QQ82B(mtUV_V3Pqtw>FH3d=C8;g2x|RtpV;d0NEm^zkT03<0qf?n%WZvIFN50LdheYriqcH`2%-<|7Ly% z_%4Sc>*!zet(#1cZzV^?)ckcB+$X#LgRDqLa-|8a!HwX1kYVuCt^+; zUBwX;K^0e(e_=*=;u*{A?aFRM{9vjXHs6y}r#Mf&2Hc^Q!#S=oqID>e+%;jQj`{Tr z4J%x;LUc9DT>ncLw4kE<7ta?_PzVeLKp%uGc5Lq4UGvWsRYwNmbaVTJD>bGczi6z@ z9yasM&x=AY<{3%#KF-KTTBv6%b@>>|vTAl%P3zR-vzi_1^lftzdQWDqS4z6Om!@lC z*rJv$^I@MP#w(^Wh7K&7KvgkCYaxTQJ{lfJ-lw@WTUyWBdevtWWy~wReiLya!Uy*z z8v5r2_rlyGn{^APRP+-PVSJ|t*FxQYp>}jt1s*bHs{epmX1KkINduHqbU8byYiG*Z ztG#o4qB+>z(^Y?>r!9zbplq_wZ|`m%F)>8i(F+zWq8=OfxGaok`10<)M3?Gm4S1c* za{>WAsKgy_l}N2GX0}GE;$%ruclk%NS9)Mz=zY>Lh|Ut7N!Pbd{s3LX+2`@Hdk4#K zTuD_ox1rt1wCLzK=(5nQ+}8u3e(ZFCBpGSs#J#q-gxS^N7c%e__lNXpTXe0t4bC6S z%Zo~w@+uw|qqlJC9;fnNk;DJ%;vlTDhG1h`j7p*0ODcAw7_z=nNMUzs`!uyX@q1uv zf?DQ08T_`AB79;WG}~X4)B#BIREeo>`>#9O|K}H1tqWHG2TmaoxX*}&--&Y^;=B%1 zmz|xRXC44S|5(&}Pz6JnS3F~}kK=)m-wYlVJc#5D;M3V?Bg4CzWle z)_4cqeq$fQm4{D0e{}b`6`IG%+Ljhg7y~{aB6DO>FnGlAi0&=t1{n0{H{iZ`|Z z3@m*yQ9nbdJtoc(X!A(-Anp@9 zo4x}6D?Su2$B*1Y!F}-4Q_w>}e4*7JkdbLxp}>5qK}%o@Xv=Ni1;Aj(v9S_VdplIM zx1L(GxZW~R-%oi^b)ak);cR}>`_F=cvLnkq;2X;pcyDNpfzh_!OdQbPY!J~Vb9TeJ z=Mi2ZyO;0~_8`2Qyp$`c;?Iu)R1w`CUeg@zg+-W+)Og31eYtLTqfgd9zxQ+JR>#ma3X{FZzEW z#Ty19c*IK(3S^A<0Gb{ZjG}H2u2)Fc9q20sLxwNSz{0mnS7iMCF3?!;=IR73!?;FM z!MVJUi~n&QDQ!B zKG}n@EiwMScW~5%l(23GjZIBNw`U$uuqZ<6NyiNZmp}WD_-modN6AC7J$ANbm#B7D#iX?^2t5Q> zp58n?fBrntNJyxQIzW4-%rt%F-dKstL{Vj2nAZVD*A%4m5wE93Q~ulv>Rjx%7xZs{ zcFyZ5P>(hM#=d z*0{elqNd2;v$xE}PwuYEHEbT%s#V~mK3~Z;X-iFRh9vymq`fPGsdvkHtSgbT>OOJh zcUd$-c>6sIU{Fv4kK7eRr@|E>;Yz`DMx6Ef5He#5tuJwCK^R@Kg$B4qQQyD@B6$GD zY1+R>cDxSAP()lBSh~0r=&nDKOF*|b6G8Od!S>@6gJs>T8JM?V?VxgB9pr4Eyni1w z591~m&AuWE22xRSyyLVCmoF27-YO`4OR+igc1HZ+7ynn#uA3nCBV>5^@lj zHpnlE_Q5j~4Bm=`?8L}ON*^st?n0APG9v%uXjz$O-d!mC_3Z`0v4zGw&uU--7_cs#={3-)Lz?<9U5^(w%Z6q+p z-qGX}zwtxi&|b%cAd>H2e(gM)&VIka?9JKNHB)+=^4R}He7-+)!l^|!H(Q@_XwD+p zQ6felqg}em(D^wISeYZ?^CfK&9CYeW9sU zL}K)>;uq#B^6P!V`dR9K3DTd%Z~k~4wEx`v2Y6?r4S9Z#rJl+D$o2k%>M0}%>ED4Z zVy~qB{Tb9N2qkBWE75Cz|2*(2{tkR}@O+l--hDXucKG7X1omRzmjIsKU^BggByF#b z@SqAn(4d%3Irp*OO8<|Kq9end^MD<9*DJedL3tFbJWBvNHYJ+8Q#;c7IZC1}|Qi`j>nBQV~PdD)oZw=S(8*9} zeyv!=)`7&T$hc)S!#fWoc1kt-_G{sI_u&xd;AOBY1~uwheMp88qr zy(km+(wLC_-m#xU7g!iCbSvIlvAgq0j8X@kulN)*BXBAt{;7VU>**am9|xL+%LyfT zn7n7VE2B=;Ovlf|bUR`qWQlUj5Sx}Mp2;?0M?PE8DsLDCEZkjPMOs@$WC ztPi$)a!v(!Eycc2f|M8b4p`4^dB6(KJzvuLs z8JK?a=FeOM)(2~$ILd?JAnu|?h@qdK-SWK!+CyxdGH=Fg={Yfm757s3vZazX#7``+ zhFO@-OKq)a(sP$yS&VEut*_&_U=wXAE~i(nc4&z1BzMD z;VDk2Z>uA=yxvzk(wIaQF%WJZdY>=qFsappYodKZ31ERxMhono-vvxY@`uoJkK){^ zZQDFC_1;K>Ik8)+wUp!>$LuoC+g4<$4Ck}m?VVq-ZQ5$47EFv(v0%fzO!(!cml!Wc zrGfDlrG5$B8MgV0D!i_FK+OC;00(-}`1gU9Mloui`dxDqwdyj#*@EC!Nj7E*+9iUz zC;HQvl}wA&ejLZ&?BFBDJNML#jTX$RQ2htlT2{;hw=!kyS zfi=3nl~IQwd(nY;od5jhcwyL!IH-g68_)+!5ey^LvCR3*_Is(!!EdHM`8!Y}=sPmN zjcC#XZAAjmPC6Er#x1aL7rdwkv9|{VhisaJk4gZvinI-EKv<#{5%T~F!Lgqb`v>$t zV0isELfLg!6@N-Cu1~(Ng;|@f)pGGKF!W`0AC-0_JM1Si6Oh^B#c;k z7`HQCN`6RA^kny3hEkzg)-+d}*qxgnCq?4v88q0JA~Qx9Fdpe4e}3}m4H(zyJ){Tq z@|foj0N(UrTpJ1ke{raHrK7m&r?cN~(CiIEcnUcn&H9YxxSZXrI0Bo;+NBR%*9#na+p3$O0q$2uu=>t-*4U#f-rKIt@-xAOuG#@A!AT=P8)l~ zJ>c~JkMF?1chJ8(g}<(bs0Wa>gO)Z-`5L{}(zN=B+YGa!j%$lBFG<$x4&&6A4gOP9 z0yheo+k=o1XXaig#m<|uPwwAHY+zB5TOAJI|C68(`3T9LU|To$g-%If6YB8MiXIWL zb032FESqk_x^N;HTcBbTj3O#!MM-*gx6qehP%2*IXTtm3h_L+ytcDPtq$vD!|9~Ju zH8?lWk;;$2!@pIfq$@yl zI{4b7*tjl<9>a_L1 zz;5jQVVZG5yy*xZ7{32!(%{P6q~9`{%5xBNj0EVgr@&?B5ReoRHcGekHJX50_KSJc95LW&tfdu{_% zF>ArlEFRp{UK!b(Wl083m#g~EZ?6Yc(dnFA^ZO|b^)$9xn?IaM5drz1vkiRyaW$3&GnEg?C%N~3;Hgv3+m6yBX#G8eb6 zs%`U7?#w2CxGUf=R!o;@Ov8mYbcgUq5AZt_*A ziR~M&((EvSepOPhEpv#&iOoM*Ef$WAB;>iWoR>m&)fi}xJ#h$7UW4yaNjCDCdu5py#iS3x}Qxh(O!!{LC1qu)dy*i?BC1=yqK0A}?r#+E{-Fi8Titt*z$(>;oklni*j>mY1z+Th2c{A?AZ(_ewHQ9)-CIFz^YcH4hN>^NZ327rN> zDlsZIT0|&MYV@Q)#R*-y{mNJwAs8MYZU=x@J-PlSw1vKGAmY%$CLk00ZK+UeHe!qk z5$Z62>B*U4n3)*{pCy|{#F-j4mSj!0rM8(RumfZUE13CDCN~_$n??M1M$YCV70c=R zA~33Evl=KrH|~Vm-4Uq+BCzjKDN>D#85&Y3T={cyRS!{y5!zq8Laad2$}i#<)Nb!H z-Jx)5h8k;(mtVvBTbu~&yH}8|Z_yXEqvYJ?T|>Xl2>-oc`W>)x(FZlllBhf2+s`k>*#+=5BL^?=mG-- zj1c&%5TH(xR^?WpsLxz+^>CP^EAjq-9R`5~9J@m!cl8v0;@XVZ@x31GoJAz?9F21v zo$d5iS^$pvCZVgkzxQQ<@Q(g+E;-$wX-fb}@U|j+B06Y;@%GlhelR_RnnH*MMO`nO z_xFhXI87a34XL4_W{^DfUsL-mDG0VxCM=TsB!K3_PZV9nuvyG6EEo*egeYVGAgXdg83|@NTJ9JibQsgE zhxrd>jZ%=##E$Ome+IP6fAwB&Ur9ygSaOQ{8JF67&|snYjI$4cVZFJ}&l)U@w)jh1 z)p1FEcwLWmoP?Cr;woy;mvC0<*3S3lgn|Mm4S|d67vMd)eJ)1G{77l8wcP8{;k1wU zuKX0FQM|MIyDShb)0oDa8r;cuWNEtRyGB)6_KY?pUYwni&hez;O}4QgpEhofN{l0c8S&Cr&a?&w6mOP z>5^v>&#Z!bxwXHdpj*d?vy^s3(9bcP=5=Q&_+@;cV>ZWjA4zY%bYTqUa&L>vG>1p! z-LyKI8EqA`nHA+<7ZKreNl$9MJM)@<=ej(>Bh;4`7k+i}uWybcrmi|0a-(A(Dm{=i zQeVWzAwR7cpuanvcwqHGnwg9fuEl%BIpgn(_a=@}b$GxdT8jMy;J~ZfbV32F|A+VD zvC+p-45$gL$zM1^{6Gw>k$mbY?c-H>TcJPt9184N^)o9^?L)m{^ zX>Y@YyovwHZ2w0GUJ-^I60S%dqA8vMUt=?vgv^6oToszw>4}LIqJHPe2QP5;CLnsu zRo*WG4jnqjEpc1;4#f?p4_iZ*1T%fPx~c$Sx?keDl7|Q@2!H@cVo@M{X)|a#kJqXh zyxwTR*pAYdMkN;YTyX3O4D0XN1m60(jN`3iMT$KAbcnwA1yGmT&h7x%up4xkreI<| z3-G6lI=uQxtT>X+BZvy2i{K7SU`p%t{4;u+dJa3`;i(`mVg6Sxv`bYJZmG3v%x@{x z_;)zoqy#_kbJ1Uu6Ms2RFkDrD53`V+V4JEBeA--4BJY9%@@Hhot65h7^V*Z-sQ3mv z52r(dybRs!PLM^M0yVP_K)l^Bz@-if=jzly9R+%&_u3cF({fHj^6vmC3t3$s(9Fti;H^cRKcYa+^)s1)|DEH( z>vhL~6nxJCN`z4i$R^1;PpZPLJ$XEMfyZn$eeU~XbX@u;NVxG~+ncn5Ct7LJlwKq}43qay_jB%QRIFhSb4Zh%mvbDK zT)fW3ian8)5Ln~O10hDQ|T?Z0XHdhv?zOv73 z+l;Zef^ylY*1g8J6%Sgk^SQ`^8pG#eH&v7n1{*<_Wn!Fti@HHS*p}(R=`AhojMJ4D zJ%!e=TWMcj*XQTxYYe&W5B#UJj#8FD{XC)IzpK*XChkILAA@tBvRdeSUpv~lBv_e85x9TsW z)y&V_zcZN%S`R%v;nv5hwA@K$K_Sc)^dY5DqB0H(P}FuqQELJcGN)sV{)x?rM@*lf zgA%j$@RI$jm=V=b{IRRX*l{-*G~Z`kyeNMXE7y;Iws^q|9Cw$;lVa9mkTz56T|@Ax z&j zsA)ag3pZ4t-z61@ddnc~QE^?xzQ52=l@hcK%rn1|mTIIL-YpqVjGBV68Lr%^mZ>gd ziAWJm31MZ$1)`==g_Oj^#I+t0M`BYHs4sVlkZI|ITSz9f6gH^r>pQ4|m*Giq1r_}% z;L-TnQ(~NP2%n33!ea~v`0e5HMB$aCszQBL?F_htHp6vzY@wEwfyW5to2H-B`28xg zGbzph;kb)fm@5Lu5qIhSH8414*p59ry9RZ@s`U;C)tNjyZzJwwOI)TwpmD{^!@HsW zBQUE5Ri17PTa5B!w-L*)Y+5Bh$$7YPx7Kd~it4=I*QZksXZD;Fbvl)zzZs)WgGM?W zf>#0_8<*J^N(RC)80J|5S26!%`OtP>JlF+VJi0W@bMN*!C{xRuqpXa4kgx3aUx0PK|COmsu)4eAqTdr^O$07Y66vS+ zKSMJRH`OUf8b_M1Jte#M%USl;{ZV-IGl6(n%Q5)UFod2H?O8Z^l^lECilsK#&G`c+9va0y`;-O*P??_)gY#F2G_Pi6SXV z6v^bv(zd|?A64^ph34`$Nt16}(E2G$7xFZ$mr_?ep4z`T!kviPLAaii-!tF2!6ua0 zDXNmLssq+=c^VcNI^#|*+zt?{`tjh-G^nYYL#?t5>IL14Q*NlD`3#zy;HW7H(*Cn1D ziO;H1|HkJPj+lSrGp^i=C~xmzHZ1Oz4~f?3g-YM6LE?cx^W;AZ0!2*_9LsUf9`0b` zwc63ag)RT2{ttS03H{R1`Y5sjPuYPAwqrf9)hLP;W-HXteOQ$Q2xLQ{dztfpjh5(fPlR zcGuwIu_bO|QE&Y9v+rCH3dmFx_S=&(I+)s04?~uyw0$VR?XNlWl{3^;swvc!f93MP zF+cKQ@HkZUXWr_+eR~K#p?>?0(fIWP3*4=CW#;U?IX_}BMM!NCk0~I!2Fb!PH^PGOd)H>Ms zJ>!>*$00NPwHC>3c~3Vh5B3s1@M%kjERG8XtG*1G0s^-T^Ky|jKD&VI98oMKc}!Lu z^2ZNZqUULFAqqmzufXicnd69)+p@pM=I27Yam6`I;uAP;*TNT-D@PIh9ncZrooRPkHEirNhFpL?vFvHH6b=0t?+=D;CkTnhtH?- zJ|3a?LeB?JRSC>WH~3xu@%(#_Z6-IbN$X_&k7ghfOO%3!LQLi)J{dYJxhM zikq%KJr|>}VbqO4S4FgrQ{TUFfglMVjE0(nfx~t*DE$TclCxLb`562WLb}DdpNaB{ z15@B0`N4f_eeIAABr#|R@(#WB9c{+)V>cj?%tf@!Pw`D%>o$7li$AT4;JDL}!1omb zu!iK|QGRJxjUUFn^W&|_>oU3OA&YwrZiXe25iXqiwuRIy${lAzp-NbCK>P-$Grw05^<#_DXcrV=a;r&VdMi7={uBj_m5IN z{6KKl?=0&l0g(1m;x_ErW0gpGf|JI9^a) z3)#@#Sy$nMGSVD2V28P{WF3^;7CqN3oxw%;J6)KjVv&Gm0cZ!gQ7iOG;$N$g7uLWP z1+OF>?T2$uS>$lF&V)CZi=-YJrP}^=R3l5@8!4{_Q=E6_FL2=c$3tGrhDD{HF*yytZEyVpgq*~W)klEG0n7wa$nce zh=pG6WG=Oy)@NJ=(+=R3ciN9M;u>aeetNvPId&sM!>Qnf%Aa^(RDrE?;oVz}J6kdv ziodT>)?2yb-9*pdAL>o$-La#^A-pU*0Lk+*nxyR&YkpCNI$NqnrutihL|BTv3utZS zTG3qnJBr84u3paACm;FHTHZ^NS~ zod&S_;}loEh`3JKxMIJbV=|EUNXcT5n9$KPgcTGyIzX(=Ug$uf96>j+NcGpDFMACz8u2g9>G@ z?tqj&diLD;jK+s|4>5Nz+PaQpH6+9z(z0+_YE#)ty@^uAJsT2p(*b`35x*@>D++}$ z^xA%+ZvNX5`}fLLcZ5ZJhm(;kr$0nftF~>LS9+olt$n$*%?+$=@^-d2Hs&ag<8po0 zOZoL_-X3;-kjO@WU+tk_?VLxtG)f}cn{S4j0@b9$9d8pcbAgXlT75B}wp zS_{MGAqIRl!^B_bKp$6lWTMN!zQtDwf8l2ZCDKxR|3hOIh?auS9))4hGP|lBi9s!D{e}l$q{kdmhxUOfn*KIHvFzk$ir<+9D@kT3gBU^25)-w zzzo;qrSfmmEhfF;5ueJ;4fAEsAYk$0rmK5k{3VG^Lw7GFClx$S4^cbgP! z040Rp^J}H)h*0p?eDIMNyIu>jI$M^DGXy9@+5yAkT-+4V7_~q*$I%`QC6(|xWxvQ9 zA`INulQYR!N){E%m=a_>dJR-eZmkSft7L-E2qxQCXAEFSMg5J^W=?1P;Z=u z0?z-nbp^Te&kGZMInyvW)};GbLQnAuh8(kHIwBZ(Rt!ynbjKp2+<0w+EJe;aBsKKN ztE?HZQRd2~Up4;mhE=cPI|$H@#o$gmf`I4nbVf$=(+rBdJkZ0i?_;frdh|I$58#t8 zyY0^Buc4}Ie&K`Qbj+T(O+g`84QCg z#6^ir-n-Z5Lw%mduc#eOEU{LA3JR4r9jF@RP-rodm;RYBY)a-7a0mdwPG9dMR zM|wA4Kx`3XeMe88)7L7ro5SK&L+r}k>ps2Vcja<2zt9c2h9Kn+up`Z0RymFW1T>0X z_$Wb!ZE;-PmfCJx zYJ>(5gLMN0*41B%xzg|))EjUh?1-0mZbTkCw{DFlF@3H1f$=Qx00t^rk8^#$|F~SH zm3Rtj?vK!l#O7AAnRZvV5>6<`J-StzF|0mNxOBvVv8)`IZi`dQ_gqMBSzp0`#z3L! zQ47XMp~}U{4i+m^bvf>rO;#5sE=;v3Hd*qCo^VvI^9ld9+z`uNsql#+mg#p}y(bz5 zXVl&URf%9RC<;XJP9tuusN=*(GIHN?+p?cH0;qta6_~W*WF#2OT(mEB{HWqVkg=rq z(M?mnF{VQ%nz+DI-wI^<*4P+=RbH3l@*F1fS=uP1jumXjwL45kvfN?))0ni_Pq|R- z-IhQ`n=!oeWmWmsC!zwiq{rZ$xt3+eZhJJt{=oo!XffC7@jUFPcer$ zmYC5vF+ct2`5Em}@Z3s_$ZP9jJujF2Bgp1}02zxA|Er@gL~>38_8tn2@C9Utx=Ye@ z{l&R{1+8~&cX?7dz|&CBA+hj_3Qaju6)34}0;WlP-zQZb!~Cc1L96cRq96A%tnOkH zBhJ~NOf{0j+(=u4ge8`ht~FE&`BZ; zGj{hDco&BsJbx9|!|CADuIxF}eB;>fATDyewY4>z;#muaf-n8jH?sj)MSDG_7j}92xYscGzh5c>n+L$+w|m zV^wr#-g5;w_-3I^sI@D~PE`M05+e|*`2VHT9oRX9gZl)>ipF&85d0@0B6~aI=KV+i E1)_*QV*mgE literal 0 HcmV?d00001 diff --git a/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png b/graphcast/docs/GenCast_1p0deg_Mini_ENS_scorecard.png new file mode 100644 index 0000000000000000000000000000000000000000..f7c94ca7ae20a5d80c48496e308358e5ae5d65bb GIT binary patch literal 158143 zcmb5WWk4Lu)-{X00Dx#zdhufE6;b& zj}Lm7o}TXNs@+w)_Fiihp$c;1D2T5RAs`@7BqiP{K|sLlLqNbX!M^~G(2rDJ0dJ6w zO5!3Ar6Yv9z=uv_kfe#MECd~}4-fGI5)%R%d_I&`Hh}a5bPv0 z93dcZslYEtNhQi-;QssODj+A2tPHQAtu>Rrk*$F-lbf|2_$ml~H(p@Z+So~-)Xmz; z#*x=efc*IkUSJ=5n3XDwf0y&l*wN6z z+|CKG3@Nx=eFIx(CjoMD@QwcU^QWFp<|hBXla1qF%>o+84E~0hh3O6R|GErl?7!;w zUst*5+x@xdxmAB|#t$xvSKh(g7^p6|Qi3e}&&B`Ov42;NA1Iwy)WKNa$<{%|*49e! z|L+6mXa2t>{8czVGuRyeVV*y>e?AHrUl0+v^6v*Jh!|&tT?zpq3?cdMt%@7u?{tJE z9K*Z5G9_9?5hH7=77giUGUL$fca&uF(h>;_S^WtdmZ86;`(wC75VFLKf8a_g)~Y8k z93($oG(8;=I+`5vEDn7f;TgTr$rwFLOLjlBZe=qa8RD2#>QTf%MEC;n51ViYf(qj` zGB+e9=|66wJ_Ed8;VFfo{`aOp36oz;9|PTC{*QNoh{xyn=>0C7j>n zKi7e7-wpL2ix-YY&#Ww!K!y4K_8;mf45c!2Limrx`;wCh8|y?9!pr&pLmhp9Hg59# zhY8}-+6io%$y5tf5bg!8*Zku>M#F8d7HTZ~aNb@#S4)oH!BckziJruV`(&?Wbj=X#93Fke|hQDc4 znM@QY)9_r*85+mRa=!k_&U5($FZ}l7gd|PiY5j30(pC_G4gB(6UKW~W^c3O?OCemwm_vlXxY0YkfJ%scWGNA8D~PnCVd?)|{&)XR;|v4Eva zr>@Y1?)6HYld9a8yKD<}9Rz~&G>g`(EdTSw$V;}WsV|-D!9z?*~^6A3ioUCJk z3n@I+hpwl%&%MJBwnF|sf^DQc%9wBGO_@Q=qD|AV;+FBkP0q+c-Rg((f;OOqGdzOF z?a;8P%2_XZ1s6vc>#Q~-xm;AnV_iNK=1b3irqe9;{0&sm*JaVVF1^M^5NY}0!YgQ) zC*!QY#+HymwHX%l>8h`G>6B&Y9K1vZYa$2kB9dF^36kik_43TF+eC_IfX%myR z?qPG40CBxw(U!^khdQ(I&X4;gbrU0rIDZ<4DC%QVerqW~_B^QJG22sbbX=EJnVNpH z-T=?W4`(8-oo;VEm?R;?>-M#*tPPVZ6unw8Ef376MQYQpAM?(SR(17iV&oKd(i}&k zt40MELCfEYMh^lhj`z#j268?NrZ_<9dYlBdS4JsuLlvo&#PB$78wvR0eQ1^-!M_?J<1^3*MFN2;4<6Y)5^F+FH=pf|1NkrlVI)W;*78zqSkX>8 z7z?Bio6qDFEt{v}>uc6v4(m6g;v5y|f275imbF}_kXHL#**@OiyzU$9dw+C&c=~X) zn(;Y{{NiRO!~4C3ttGPhjMK?Yl+dGFhjfvZ(QpcfLni_a{L(eOQYI)s+J%}KwfRi5 z!QtEauW?bqyWecKEQx>hDIdVQiP@oZ!-8EXEj1LI$ z40L}<@4yhE?F{IdA~nV+acEk!lZkxytHTMv!Bm4VgixheyW^GB6|jylPbjjSP{MIS zlhQ))O=J5=yqCH&&utgI?hcGMx~=5a~b3*M&0Tb8xV+@szV)mY6>fEGG5omQGZ_A0G-N9gam0)zW5)gk@a zw`PvPAZ~gND*k1?Xdgi%ZoY%+SdNv*7b7l*jr*k$c_nC?BaWj2J3x(l$Ke91t;yRQl&^1o*P6dZr zBfsaX(iL;w^MuUiA@HPa52x~JSycp9t$f-~_TpT$ViVU%A@uy^cn&!=4p%5FR14+I z=YlRS3muZ`+2uvPf+I>z(8YhF=Mx`HoqVkB ztg!8H>Sge|{$2dmw)VSwrtYwAlFR+y~HI z_POQj7A)vZpB}Ebb}+ZUL`IMZt&Brl<)O(cZBCWxuI`Ry`R7Qzftw|hXfc8ck!Z6{-QF!fV#h>`vy$iG_QH{K*=Z8SP!?UEND4<`TYD;iEsCzM#91MK;9{&kCgg_&mbUwY zhn=Um40qf4?j(%@p0$0{B1{{?An%odt-QP?-HM7FrBf|T3RbGIIB|)tCcH#P6${h^adTMCATJB?hZInW zD|y$(%L_eB(attIO4k_~rOCA5xx47h+be&2ZWb{tYn7DQ6(Kr5q-p!{oY;p1R7 z-Tu-^ONPEBy_EUL}}_P&*Q&|@U5=ytELuJNz-B{wl44t zK{~V^UuZjs1}_U(#EFJw2*Y}2sI5k_(!0wcnWajbPo>VXdpW`sBB7#)Vv-GO?24CH22-hN@jw-ilmv=dXD~d3#bt|9}^XzxUb5$nVYGQ z6S$c2rZaAEIveJKyY04hDa0dj(=b{zZT@v8HMPoNtE){7%uK#YZcTq)9hgPAFBdohwmY+s;j(&rXoR@0lMbQQv%3phltQO0rM-M6Xgr ziZ$mwUum2*y@nVtCvdxlr72EA4r3#h@|L+;u9fqbXaS+u^}65MFHS|{2>TWhgU?p) zaPsC=m~R>_h%7Kg+dcVY+K++hBDTV???vG+|AzLSX=ZrPkt?$(QCz|458dVZf%lPE z7X$QVA#|-UOg8#{p7zVoU3+r_99nriv+E1Q+n7hz!ACP&iMlpRxLqT!<2L!8hS)~k zQePx|mrJ+M&Q?B9&s0Q){hscUlJBgUO*4eI6Ry6Nvio`K>v)4dT6<8ftc>|Z(!=3sJAy%q-wUGtR7Zn6KVU3kEzVhgd3lO>1ZV2NO``Om72MO~a=Ecr4ia8rM^u6E@9yg2|FrvIP^DTTlkyr&fQ_WF&3(fVaq)mM>E{c!DooPj_mV zU#F(qC2cd9$OV*tN~t(r3DK&sF#mH#%NB?BB^w5uV10(*aj@N$0%BZl-0wf_8a z^f2dQqJYiQ_2F8|I$n*!C0?>~Geo&g_=A+tW2weT#TUSgGT@ws8ZMm$@Ig~)t6|WuU6d4+CWnrW0YlMZU)pob-m7P38If-*BFi3;Vpt3qBLY|2mY$2yFbpv4FVaaJzv{BjBv&B%8bG4)_V)${3&7ml%#~-g zHEIN3pfXf2HT=XJuYd3NFg0h8CxA`m*54^usHK!Cu>CltEDt+i?K^R@2EG4k2muj} zKQ~L4621evYPAq%hF9Tnjag>h&v0|bUI;1r2P*_=I8urkY-7w+$klzVDmW!q!6`Fu$Fd!kuu zzTOC*xn>FZh(jF${awh-{OSJSv*7J7VFDO&@d?;er{8ld3T;4!3o%G=O>L?6KV&RRqjl$0zMT)tO@a%wgzA<_F z3NAxY%AZ_?JFC8IQmuSc5i75e_xSwaXNZeiHcnC{S8`0K$`5n~1mkU6E95w$ExE zG6nY#C{t-=dd+#w?`1#fxXkx-T+Pzg+pKHuL>+k%+qFZkScDq8HeIn|K`8i-WfvPT^)W!aTk@y7kM*4q~OUn*JGvrx6!eU}? zI%;idn=9@T8De(a8R3{QdvsXjy{+v*@t+Qoy`OIs zD-XixfI>U-FQR0S9($2Mc}T)wa_{HjH6Bs@NNz~;~mHsWHcCxii zf%`#CA}ME%uAau0bC>@`qU`-iO0#okZ3j6Wp6_(D;1XO)A(S|;^By*>?O5j9_7*mU zU2oX(fRl`;IT)l$;0u_P4ufd|!-0}F`kgSY;~o}oZgQPn#j4m+9anT*g*dvtH-XkO zy!LtPe+oXFPeSf9s9J?m3*{f*=nJbh!nd}47#~W3DpG8li~8Osm$tv^aTeL$qJk9# z8dV12Qn+N9G~j}K$X$|6AZotCDel1T8+$LkJ)f_&*6JFY@n-Spvp$nj#Wg`dQFs4M z-`W&{k{`+s)7GI6Wx1;@7VkqXnAg^_ii10|x!D_erp~4o|0wcUdJ}2~jHb*qu)mZO zPrZ`@`^(qx0Rm3juezwMX!-j)9>$pAd5`Gw?=uA4FVEGv_LF-Jcmd3%^W^H#iJy{s1ETNzu*o4I`|RAcS!610tJ&6di?_;Q!+opTi>5}l&_AFminj2rKUdh z)c0qh67OY7oE+t}SsUy0`CBFhz%;`NP zSu?yS4oc&u-M?RzcJ-XfS6pElVqJP_mdY+SdZo{FslYHtJOvHCl?4$-S&#b(HuP1d zksa(EDs>=w1j?*_>RhQg&V!D4#Je`Vk$O}LY)EV;yTIw+g-SW^QM{)?Gbf$_tT-&A z_kM{Y7ALrcgvfUv)PA}y3GF!T2NODkXxnwb?iPMqj>3Dd*nQL+)_UC8PXLErdh_V0 z#ON*SNAomyuXBdA&WrFxmjjMeb?4qeX^l9b(F5LiErR1jY%pUdt2^nMC?~wICpt zABGJp4KF9`oh9meRoIm3wzbv}2I}8!+kPoK`ffc^&iyWJ;0?kc)GIZiQU{MZjGPJj zM??6YSVkV-I@#8Q4rD|(u2h7F)!iNVLI_S@BnQl_ETMZq`!glhsl^)BB{-spC+X!H74B?pw~Pm4IB^b~ z^j?ppfj%ZIfYmg>I3fcon`;c%Y7KZ~nISP^OXL)PUt>E{c#N`_3b zObL+DXgE7WKbAsCM{64mC9xPdB6N3QL0yVMA_RUNWRIsxa`oWXBR6D>v-c_@1$(gp zs>mu6sCo(vNUAiSD^PHS>^E-I!2vTz4djd-AuT(bIKSLGuTKpNm*2nKg;L__&!P;e z8saVFX_FZjY~yUD;LiTcdQ&}&k%hFjQhf5>xNg}6i;_H)4-)!T)Xo1#EDf9sTxsnt zwzHWE^6E+ZPd3{2EAw_A$oFKQrd8GPEj6110IViOk~;Tn%fjn13hY~b z<^!mBWo%OE8r`+ih*b`4tb3-@e9NkSQ6BGEVP6p_N`+D|&^t>p`T!2zAxR*OUM!nt zczS$$YD4@cr;XXy<->=U+7#c!I3mB877dA~Mm~?YtZ1gRMY?{T6KnE z@PQnh0TBX@E^R}W8>+aEtem>C0lz~D77b_q-aMg!J1?&G^P4Lr5Rwsni0lp?C}nP` z@{0s1!%Siu+7s_mH2daPgj3W+t?sKX=<0*FP5@-{xGjm;)!vKA1Yd>O!vs!`X9O$( zASPRz%|A@`gxz^h#d6Js)>;4L7ar|W^TTc3cg2;=6Wa;Sy04)rgD5$syUD5odkO3MLEMJDKmS0l+?+cSKst#o$dh&n_N++QLTr z#Y_NfG9B8zr%1tgxxT1(1t~_@a1U}qLT&I*VK_|)bsmuGE$jfUQMAlOOE>3lEbK&u zuumlIU6_(fPLr}Tj)XQIfw;6U(fE(sbInOhaifT+9HBx4+|~o}QY^_v6e&f@d!hoU zwgDuT$u!MAsGS!rIr5|lqNrTOp^wP5A48lFBLoE@yckGn2R`nkCl-FxBhYqo1sDb4 zo9mO!l6mVchFE!?tvF$l8;1U$3&=gRal>96($xU2yP%tP);L^(Gb;|7{RlgtvZNIt zxzuKQFH)f0w84cQ3{d?Rc!w_giD6=jDgv|dPgXV{mx z{u=8@X@*R8Sm$MbxB15N{-iGmHfhZlTM8Z|x9t7&AUWx?n5c{HK~gGeLt=(D7+}S4 zRoZmOd*Q*d>KiTuY5hvv9c6h~N5Wo*W6OZd9YM!aKtsgUq5z$Q%2S{j% zvLd>5fz-%N!u8HUDC=dVB&W0WK?hXn$Nm?gAWyxgIm6Eg8yfSM&}ivyndgYaF0;Dc zp0?2%;MmC+fPs7I;2#X;Gx8yDvflYF{~Dppk$lSd4Vv zT17Rct~6Wmyjt#Ez&b4Yg_R&P7j#C5%i6z&h(yx{OA+^)GyGk*ZFfVIBXJ)98hGAa z)_nc`$fMfL8yCc_c1W_sq0jT~M(1vnGb8gI(WiNL+vE`g?Ai z?^`!(DcfU1z;Krup0%|fcX)YPP!WBrx|^@hK6^2?K2C_^%NZ}uEt=AOh1x!Y7L~J8n@PA4Vedljq53xWqZY0)!BoEve>hYA1lN_r%1H6K3g<2`SP@6hDD6sgj3q~aw!HfianA{F0#3!}Dl)Nay#JT#s2GFHa(H&$LHTD-&OG~JKtR>^1fAaovf zQo-1sD|LOv*}<#y()H*uR||DB3X&zm>A-ri06B0eP$?t*Ldct8eNZ=%0HDi(DjJHseRCC za)~TzYQCDxlO@=cf-lOag_@=9UjFe8AQDVi2~3o@W_imqs~SQ(V9^s);$?5ocd^Rj zstzcMcU2I~i*c+M*Jt3Q6Rcp>3^`YTK7FMPst+=@ySz#qc179YE>8I(K_;izQzMxj zBbiZ-?kCOk?G(#rP+-GyBTY0-PIV8rrv?tUZl2oL*SV2yvv~bQ*onl(r;ybR5BQ~? z2elt7_Il<;i0^F`tp@emEGtmjqrlC`gaU{4kY9LUz!Yd4IV8+&kmUjZ+QU|iRN^c# zP?&LRgRG$vQ-L^Q#)uqc(8k;{n3dAQXHZ8BN&DuDqAkFF41YAMmv03a5_23Eg(UJ1 zv0R7hg>S79=#-}WIB7T6O3lI*gAVMz4_(e>>aZX!MPld)#M4tOogQb}AjKRW5)g~5 zhefVk)CRn^5xBx~Ui*s7H6io4U=eV=^ZppI^wU>ZOb0LV!YZv`!fpXXX0$$cp%OkA zfOuh-dRe&IeC7c%f<{5)d(7p*B0*2&Fp*GUCI58&lA0UT_3}#qpGOEh{y~$G6j2fK zT@Bus6@g`PZ6)fAH=W&jVM=?^nuT$)rsUTFe4o!gDYkR+Z!M-X(P9y8q)r$_%|+KG zEg_ON%azB5XT+^XhO~7Jrl&bEjo_I!7M6oSzk;iLkn^DhnNmg-s&x0(DO%$ z*nqY-)nf@tdKVHMIO-&YU?p_-yDBmevvQcc{b?wfO}3b*8>>qc2CA{BC?-_74rp7d z&ru`+*ND0IhFJjUoq`=%DKV=N`6SJTl`otDf_3e@i65LC2zLE#4@oAmquEq`w7HPO zPl+6>oAXI|Su3(SA_mMxr`?!orw&^nu%K4Mzz*vt51(@gQ4av=Db)K=^#;YDJ{Wox zL}B<9Q@nivRo^v%c^3+jeY{rCkgbNNqbMa9q2kXc&m%C9CriMLgOZ?z>cNfSHX<*N z5nW=WwDZBlU0M#n}CmGdy9|A);-j9H9<1B ztRi)WwrWj_H~meO*!$l87T7X0Ydwo^-Q8$ zlJ$2oG`S*TR} z^u5_)n(l3{PU_r}rS%2q{j*BQUlxb%yJ*|CTu;y?f_k#gp7Y&| zQ{fDb%MxZM7>P&w=H8E6YrZk~?rsf8F8|hdy!7Sjjjbo0dgedvoY6`7A)oWuve zHW}C-uWbNw7csOdem4#D`7gHyc!g6jp9JR{@^B7$b<(Qm3WlWX#V5I+`bmml}#7#whfGOmw|^Q?J>&Z`S&G|nEh$OF0L}kSA!}4 zHrDf`w3cl8HG1}g`*3$|s(^d#P==6h0)qxQy?SXJm^V%5b3UZ~h z`*8j)*4Kedp-1s`V)~ZXznlB7cK_KCX2H@MYBUh!{cG(0pRa+B!W!Vjycm_sHR0a| z`D1~<{$?hD0r(;DU`L9-yTrfx#-ANpdLV^}pI+Y{`KO9tvq94Bhu!?Lv-vLrJU{(E zM~uaQK2XWKw50KO`~J_(d=01pNel}K(MaTfRrFC@um1~~(l@F2_SMcnBzo8!fD_iz zthGw6_C_J#j03ncm9U3RITZ$9$yjQtbm{BErBC}MID80E&s~rQGqaLO+|KGxe~Jf!l#X|n`UEBla+*=Bw3alUh) zF%ei-N3)xa2|ohjl|lPHA{sCd0Wg~%ya37}&eNm{V{dR`#0Rz?m=KW`4MgUR%Fj350zXs)qB12ca79UnG*ym{-=|mIq|1oS)r>5Bk6pmQ-|%)XmM&KnpD#* zK#dF^4&?UI=)n&dodz-($k?ABI1eVMT+cVBO)~m%ELeBq!hI**^){8VWO2KC)$5LG z`lK{JW=PRkLGJDZn8Vzf2F!07cD_0STz6G*Z<$6_o?{mZ7q{hX>9E@nSR$Zkthr#X z@^MGo>t-uTvnuyD!NWXgaX{&$4&}51j>O1A2CV z-)c7p6Br%$X3H|Fz3Z8>o4+Vw%0Apj&RyYIAIe?LClxfj{}@uV>jjQi(>vVZh0w2V;YR$Zx2 z+&=-@3s;r$2bc#^b}g?Dp)02xnvs$x(KB9a_jMR}{oRQD-g#Ds2EY~L0BIGI^~2U+ z;^0`8s2EB1JCbaX{W9$aRwu1El=8eqz|XJ=QXe&+Mg0aqVOG!EbK0&-Mqo%oA}j%A zfF&g~!)t(Ch9#j*G7zUZ=4jthuynYAJj{`)P$3qe)CZW`?qiZr8FIJ)XtkRc_v@#< z*A2h2!(#@T>C4H50KN}P?D}EVB-I*oe-}53|5U;-{;-4#n0{{nE~7>44{4fB0SwG> zB22lE=YdH|1X(|gOd<#{X4Q*XalGpRSO&`X--zCz;ShTqzwOYa!{>7;*8a5nt)LK8 zCjbNp@5#)TlfHM}ob4bHmsYC6Z1|Q+Co-L`K0R{Dvl-R_z2=78XY9oz$iVKq&`CEI z1xn~K>@OOfmjM#~xn^mHx7sqZ+x0sRld@btz>s35{~LsQxy>tg0DypU1*uPIe(<_K z5qIU-#j+VYoKg ze^AA%78d9plQymhs`o1Q;{$B71=d70IPSc-?ay0_k!Gr(lSkwIMfwpS@`tsB9?q3{ zfb0)?xsz9MITdP)0bF4lsPfWNt%q9kNosne!mzn6?y}LFR!-9{3KGYYKCD0yxh}3G zbFu`&jkZQVcUb3SXLYQMvbM*@B_NqK@l1D~i3U;{_2c9cAgQDn``LPT$kuk`ivlj3 z)YpJaA#S804f}ca6P84W4`=*7&?s?+wy*BZ7i%#3i@0!rv8d+09%+nnMO05%H=v`&^Mk-8726RH^cj}p{&Lerw{Su#Wd!a zIS^Fl{1BDN2$`(`#B994bX%N*mXDvV6>{($QSLh+EsA-eR8cvC45^wgcL0Xn$z@I-&J=J@?9S7BKPJ0`!V6_MSkO9505P0PJM!5A z0PLj$?w*I6MwRDS-XA$J+u?2_oSCq@wev(}tE}p5gpZHAj&n{Ax{IWurreu*)wb=y zMLfRE+!U9l4v+_~@*Gr84)ORJ<|C*QdfKloISew60PQi~jUk$*<-DMGw#@Zd51yonNw%ak)4{GjCs@AxX%r1gf8AlobyJp?eC6>9P2h6 z=gRdZo$ZEaDJ4RA_1mr_BbAEH{KDVq9@dA9ii2I5uK^Wl4=>9|`1455yhaFkJxwR2 z5gZJe8Z*#>apko@Jj3hq`(^#dJ`_)}@ZJ&0)FDVm5mp3Rtd_eirlN~<=Y`sCe1(v} z*V2Hd^Rw`if}C-7y&cWk$d5bo&7&Dt0QfkqwNIR89V>}BFG!{+x!}lCRBF*yL!Tm+ zVsk}J_KG6lVshfHiE;XLb$DRrvKn#QQ-Op90}}fAMU-_9&N0}4O;zDiMT+LQdU(F9 zSow-m+)#zb0y4$rgF*zLi17IF{fOV@a}NBSP7kgMihx^BTUp@Kfl zuNOt?mVcN&VRQ63AaIDKZZzmYYyf7@*nvuoiiY@+WH#|HYnkL?Y}F<|7yRe=0KPIk zMUxM9AdyLjZjPkwclD&45+G#tuFw$Ae$m!0_Bx)vxXr6{-$t9<`;M`8u2a`&rel~| z;uWUql&0JiBK^KiCuY=8Dwq6J%wRX(6ho87)xLG+j8Y|_A`M*~9@6S-y9eTlBu=x` z$Y0Pk+Aj0PF*YOIqvyW5Q@q&Bev_z>crNv<8LlfIU+?$+VB8W;hzzlWt^2Lgal|+* zqnNj(pn+qTJ;1AI>;ZB;1!O2v!Kg&6E^MFg?orY~$OIog?NXh1+fI=yPbsknt>m~y zf8KXw(Mdhyz*Bugvqllu4W!=*pgzARu5b36y`cvw%UbAro$(kH?fziS`sI ze@Cp;G9lml#4Im(W3*6cpV4Lx7Xj+PEkd5XRf>}-8|`tj9+}ntf-pl5D4HutS7_t> z!O9#`-)MTb`Rb>~yHwRBynMd;(2wCzHFmf)Mrna#jzC>s=e%0YpdPjWBTUd+p1A01_4(V1CnI?7T%hpbbXh45VoE z-RMdn!ktI5hWhqmBwQOdOtQ`SB~;EPq_U2rm=3aRHWF>9uG>0ntT~<<0jIR=5adk8 z9;;0{YNbvqKpId>XZKQJMY@~DsY^H*#L}16cNSucW+wp3;9g3dx8-h>$Q3668(2He zoPG}?P|<)aTj7m3Es_4)?<~7AT)#$LK3EL#8$JZH=yQZ_8mGqgX0!4?IEKX%TxJ0= zakl*xm@exrgi~`bKkjOP?aAzohz0O=HTKr8-gJKN=;x()ZGNHkO5eH;D>Y@5?7qFUk^L91q&9f!k)ov_{l8Du1B-pFCPS2@Z6To!92i8!d_9M9o= zLa?v7Y~6g)TVdoS-d9cyM9BwOXZCnjI~h%p8P8o_0Saf^ZqWBn%0@@zTb|^6+l3o~ zVn{z3>_Hz}?zQ|<1DHZRl%8(OxAHkuvzZq``Ye;^GAn9As`v)d`ltm4K2+3LlTZe| z1tQdM)Vnk8e&u{F&LMr%VlDLL3#rUh&BVfsihfp{#U%av7SFF}ZM{ExoMI&1~CvST^x5gsdl^UImB_g$YV9ML4PZtfEr zIpca0nuNq2=gZNXWfCmQhv|j3bjBDZPdKPewYcw?RXSYzFu0>_SMq)$p?uRH6qxeB z&%#J)<$yJp8F}@pfLV{Q#>w@3D=mlV91yUiZRTV0?JmhXVOQ#fD;H^)i&|7Afu}=6 z>CBW5@PEs16qI1}qQ9d#Rsx_0o9)AEj*4#%p&VCZqG+LX%n>!RdFbd`d*k9H988rV zurr0KMKY<|Drh!ADJ@Fp>w;q+s;=?Fv)W%Uj`=Lr8W^6KEoTOn2UR056rM7c+u^gD zeIZ>*&_&N+Onqq`Qk7U_7BnV231;}xDD0wxYJz2()hG`y!3pn(avgGp-Ts7C_eyqq ztVBWn(~kHK6ASBaKtmzK32KT$q1}tCffmjH>3W@5eI_q8LdXF2T1i zSTi_$DwRumL-v|eJ8h}iQTROVy$yz%#Xv|{FGg-*nOE0yKbmVMy(N-`)zaVVCPL~}rw2iiWB}e@8m}Q0*enL#4 z0%wJw^(;;uvEVv&a`;aNPo_LeN5LH^B*pHk01o=r&5#y&X|-V|hVrifFTq>deyUpK zm#XU7L1C93*gnysUiM+evCaGGu~$hxKFYi;qxOt$sB#Q>$oqMOfLGk$pbG{+Q^lD5 zkr3Lu=xM~I`lT{y+=}d}`PH_oD~tobunr?7DD41p0MV~gH1rGISW}&H{TR{5pV8DN zUx!hqwYX-DX|Vxui;;6qKq~1{Xy?0!8#*R+sDe68#$A}Vc z<77q+_*eCZ`L(@nYZvtU3TTi^Y>G4Aim*6oysab`mzs@JhUc0rv~6M$AM0X}xauk|4e?=( z^u;bD93gt4Lf=sh!B=tyc0az-gJwiNxMj|JN^poMf^kjrOFqk_-Tn`p26&@`+Qku& zbh`ohG|9%)Pr5x`ipSuHaOPI-*~6hOh#1LsDJb8v6FoSw;KrAEzp+l5CIm=xf0OpS z@7A!G%!^c21T%%nZtAa7f=xV7s&a*OL;6gxy73F)JG$_TarKk8B4y18+<_HB7)G@h zNZ93yao@|C2n9Y>ulbHsbBjOgKsjkb2CMH5n~*Ra^9R-1rQj2^IZsTas;%qd$x2gX zJ1aM`ZRTprKi=GlM--^Umwg3Xtf%KUrg}V-z*Fz=F9$i!C22;=~&*7tBr&ozG+$ zfsd;h zkhfetwKk4;GP56Nyk)*Jpx;LzSVlSpb4!tTx;J&U*mMp+x0jNQq9z zx2Y~C73+Lmwsw~$rAKyvrzL{#O;ie~8NLXUfiY)Hgv+w zOIXKP^3!PPCMR%MO!XHU*)MMTuH?~*;ZAVFphddT24A3Ici7zoKw+E=3f}z0!t(wh zgi1`go)Y0^{{_(I4td-@k+~?>o^?YP%Ch78mdwR@0$TPNYkkif`tMaX$csYWLdnJM zRoOuosYqN#ziW))ynaTD{{+x9WEsXGKNbZRvA2L{MI2N}4IpBhLQ&3JybayFuel&F z)ktdklr*iMX1KGUPB^<`0Q9m?Od-qgMMm%pMdFdK;zI;>MyDEqu#hH@;Mm(c8i6N% zZloxz@|8RXv&}Q4O8}`kGoN1`@h7RHgcts}n5<;f&tCcVZ^C^whmM$ap!i>SI_e*l&93T2xyAfxskL#Z#o)lnR zm5!(HnM;SZ4`-ieCun-ML&weXA#Evhtt#{&Ey#dCl8sIzQ>IE1w1F_y!$GPQ>hNg7 zr&fDmuXE-dJQgfl5M#Y+Kl}OY85$B9fn^p!_nz_0g4s7Zl16UZSbB=-D4V*Py~=?8 zGfS){Jo6X|$ZUh~4?94r24$M=oL~j$R(LKDdXNmH*GpQtKB$3AA4#umTDY~} zu?S*HlA@;QqGO>L2ETrPf6=l}lkODbI?b?|TB;{%*uQzy zt~8QdAk$+EEN=Jmh3d@jnMw-J?X>ki{JhzQ^4v7@c4RdEYNQ!*L4ocK7R`E`Pj%mEcvhV)E$iVa} zZpN$Q3Q7{1PNWl4k{`fG45~^+C*t4q8dq+aueH^#PW{8gLN>Fm{Ipz-2IoILic zx~*ptMui))1>@UaYz)8>;xd5Nexa$a$*-#L6`~LPlGj;x=5DAj}o|5p(@WwAdYI3=B+Xvb(vb8AJ_XCgf{~$7rbsK z9Fx9R8}8W-_ow5007P-y$VQ;2HYpLt!l5Q)|6Ph^ws zogZ(LOD!3=`tt0~KgU~m9H+gYF%(_7PuMQ6FzDM_^BGi{l;&63XJcpoMnfK`@{h1xXW&?%b~{ruY)J>K62 zlM02XtS0|;1>^}jS(ESY2|Yn|z1oFip}F`njOsGmQ&v^lF1r&wMk~^jo*i6W!!7RjrP%$L|6i;M#oe$F5=XCbiuwn!FpiQ+r1Te<6XUMQB*Ra zI&9ITX%wc}p5zubPbn;M!DRFH@BanBC851EWuf}%V$awa;c_$!NGXC2h+LP}SK|_$ zRYCokZ~cdEQf+)vI4o1#KG+eAB+iYjdPHtrB3*fV1Yi8lqRoB28 zJT0#tb%fnXu9AyAS<^ti zjM_?&bFdLuSw23_P{yNCn<*4ovx!Qy%a)NquoBvlu1w-3ir?^?Jt{6IjCr$m&-zH0 z>BS&#d@4Ze-d!$)H=n zxT#;gd~5fq`J;F*zjIA|0;ARdpoF9w}t6yV#%Q!hK&J0AA)PX@tI`t^Yd6 z(O{w4eESF>thjt1^l^fP&&s!bdf<|*0gfd5ZI*~qx83y;fdAS1*cqD2Tt{wry;}^s zTZ-+m#SFcvor_pr%DVi{?UZU%0*ES)QCN){d!j(nk4jIE50k{IkyP{&PB2LP+Un;$ zJ~yQsP2opyn6m4<7}WhMk(S@oEoMs@UlELJbsInPQFKtMnPLou`1Y*ey!JmfV^Ii1 zJKD?^4ZCPD&?Y@%+-Gde5D|#vuIWaW0;EU{Elhpl80uj_K;TA&>` z$nd~Q28u5@csAuBM-pVd+6c%5oD2o84W`MU*`8fvSUOnq=@q~&^j~5DT4*RRA}YDD zxjYRlOdSCJ`30Dx^o@X*Nnw{X8U32Bas{kne-r8vgV^vq62v5Y5U9hBGwr1Qs|EjI z)GbM1tZaQ;YMA|}Gk`J%8p-)-(=eYij{hOC_y|T?`9Jerr94O2|KEQ*k_rF+c>C^f zp!?`;NkyVi$yOmVd(UjLx2!Ttwv;W3vNy>d*?VshvPB|$kIalRtM`11>iPYi=X(En zulKtA(e*s#yFQ-J+F(T)qndre;N50|CdK; ztK8sQ9l8JS*C2PlLIIsz9Mbc0-v4!@e8`Q)y^VE}{pZU0{l*lYB68@4Ck-@+9QxmW z!mh!Mwr!T%=1CCbbBi6>%Sbz;H%pz-0U)6f*dE;5t^*Y1{(THc z&*GJl(`3*?PH+L_Gt!wDj~G56ZU5d{7pDvu$$#I~f#9)`ACqHd5NJZZu!!2gO@tbHb#Z0|O&5RKeSkpc-YO&9 zoROGuRoa8KrcAXQ+LHSN5zYQY?MT&qRU#nl)y>FT`|J1u0J7bkWXtrNfzZ1N$~B@- z9}AL4tlnU4fcRjqDnFM|w^}2&?zllKZg~fVG zRI@xMk%MM9uO`?!RJFMGctuP_aI^{u#|~4YLEDq_i*v7V5myT`oqX0oeJE~tTf8DGosJe zwr;vG_q}@-WvQ&xe3-ENltwfI*vQoGEz-FwtpW58(sSJvL+ng89iRnR)fpmywCN$q zi6oubehzLdSrymB1T`e;KxAHwSRQDUQPQSqAbQ5Fm0`||lVNv3x=>pU+MH&h^5@y$2K&+^^QSz4{=k|MZ-@X;vkZ*YpFOngKBzWN-&q+dks7X{Gao7r zMocEO!Zez1&O%qb6-1EN7LNO{Q#XB{;*5Pg%AT#a)dS*(ryeUHePC{T%}_SN(z61| zBvhgA4qthFoYQJW^EY$b_d{QLuwTM|BEY%LEe zS6kv*8+Z_9Kk)cS9Bo$Uoe|+}(`^^-89oFv*ipGo(~A>omFvw!{=&>Miq9MwZWRYX z8Zpk!w%womOYIjV#C2?hs6n;w!U*RT4y_kS*k8_d=V+B#c-S|Amd`xhqkV;S8Lz;5 z4P;A=b0BLW3iJ`Z5r7GT#u`4T3^87j3Fis*%w2eT$NEHWH8|AVY5Vq=avWsHv!GhA zjb0B0FBv2|Hk}C@g*bYi^O)tq*>d^`aq3v|`ZvIcnPgrrI6k8I1~F*DxXD8i_V|@H zt$Rt_lJsLokS~~JE=mn6(QJj#+8osq$a^5CJXZah|Hr{5o?X!GNtj)irLMpw6254q zcQUI4Jg3MrB;WhwWi%3V$F-I-9lR~gw$L17L4iu2~7cbBz3pPGGz+I8xA)}Iura18C- z^(I-3lj;f-r`u<1&uD7chMRuKY;7j8jJX^4$;CA3hR+GZLc7_tcvlDlQN!Wa)9-c+n`tjjuodft5;A zd?3=yIP#*xgZ%YR;BlB?-F+xO1Dc^cDS8E+A}sx`yIs8cjYzmO)#g*qXLMs`CDgL>RF}o{Ju%_JClIAm{X5&+2+e&?h#prx$jA;X0j$*DyT(1}l}habHzhg{=b9!6 z;0lymlvs=kQR*GjHrs|zFM6-e^_xupph~?rUUkL2pRmYq|CG4A>^g8G1K$`rj{kw2=7BN037|^E+ZC7EP1_u4nBa zqcP5$@;W^|(S4zVfKxra?7bXsMbkY&i2&lL(g3Yv(!#V$+Jrw_7(W3R^(oZyw9V72 z&o(XtGZYRQN}3*a;q(bH23&z^FxT=ffBW3&Sw7uTc9C3?MJc9air1Nv=QZ;+@lR1d zKR%8vxrpsk4;;x0{N~T(?h|8QJ@Iv!+$?5KViC3glT8mQ!la4s;gD>C1YxJo1~NFo zE1*Vkf8awivh@IDhZdmxL0r&^;x31XPk5cOBq4KJW`+)=tqd7rm4p1ul*`t3a}D2SftLlT`Bsy{ zMc*L$;uz+7s$=A=X!9IUxZ7i>_?Pj1{i&n{nfUso1t=tA&f9%tMI#;q5n+d%V^Krp z^?U`CJsnIIN>onuMvN#<%Q$bBZRj;4gJIV|Jcn=j_)fKF6UVYIr-s!>E_-EPj;Mor z)4;#evgqp{8?2%GS$)q!}rZ%QQo9?=<6XKU|c!)gdxYA$EIw#Dkz)2k*Ymw5jLlFEL9}$-WFz zNT?^dXo+37x0n@Uj8FiGeD@ya0BR%tE^thS&3*)_X9^ap2X$--MTzQ3j`iCr(Cc@# zoH?nE)|hX5V)`s8-$$L~pNOE0@3SisRSn;#s1%{jmd@ z;%F{wWGbMhMCY60u02g&4s<|0^s zyn^_11K}8{)ik*CD39tpfC^%!4!i{QW07FHjzoFVfMtj}325lAr8{Li9+$naVr>y4 z1Vgr#F3%mC78Xo*I-$+EPTNl`(;WFjRj@DdzTgkb&C<{Q5_G>AFwlKa*e-yq6&WOL|fP?JW`m=)0VgabN(;-Oh5I$l3LRK zCI0=j>skU{BOv4MXBb^IHX4k$d&&9^kagZOcMzs|BJa`(i#H%BzbWItorG~-o80{) z3@YIhNk0U2j>2iQAaeb9m#@sqO z%mw!8QTgLMyZ-*0s2kXmq(kOwc;J7IfkCUZ!|+l+yu%g`sP;8q>#M@;wB10_%nri2 zh*U74v$T=?ywInbg(fGPTSfB)_d@K=d!v;lLKvsYhM4NLWJb_*Z9Cb$Y13kCq^}rD zCIul2>-Uc)-sAa309|-1UT4X!e{Pe!5NZzJU;}nU8|)<7NNpSM+x$L#Ia-{-=dr$> zAz-1P6T5ZYAl|A0gKCxK>o%}Im@B9kpMi>jgE4bd8hjz~ZuM8N?MR`nY z(rueAsHenkhRO>>ZBV@*ulK$dM-5JKQDCz7M zS0!U<<+$J9vAn6KWp2bL0i2(){*G)MAGcK001OXT0CTVWwcQ`?paSQ)D`l&>cInT- zaSvQr&QF$MbAb~5fhH7U1!_-jnu%drEC3I4&9df8M7B+&n0>K98; zfPtK0q=3Bh0e2QkX&HKxLX?|=v`Z|Pn|OXdWw<}eYSCww2cXxx*ytC{Y~VH6F0QhmeMd4SAwBnX?FH6j^n`3@r% zbDJREj{sB278sto*OnmpR2ZAn9E8Q3oB;r!A|YF-Ebwt(`onsJ;khFf>*r@qMet8F z!Ff+sv-ocC>m=;lTqyXNF=%evJ=>Q$Xu6|1&*J4t)R6PK61%O4*`OIcgd}4_(OsyGK_iUl+SxaIPz~l=?6$Z_TcRcC&@qC z_Rl2{!r%&0iUc2Ty7ynRM=HoSR{d_!!T(=Zp~iu~5#HiH#r9t>gjNucDeVvDInw{S z3alR5CFvO<>wk`T{S|G(;~|WN|F~p(_P?$o{S^KNrA-2t?7zPa2!HYrncLj_DxUr0 zCH!?2N@4gLxZDaczw{6Pn}8FVy#`okR(calfF4zE@l_m`6E1v)~;rlx@e`-G!5M08}TN1^L4z6V~+gbPJC@_OKX2FxEr zqD7-j0=T=TzD0&eNMW|#-2)r^R^uhXdT<8mmnf{~6U>BCjF{|bqqsD=qe(hFVxNuSdrAu-}vEbo8P=YB{y zYff(9{!jXx624H|M!s#@e}$2MI%53G+I25Pc>e#}qx>Zf|HsDowd)ijkQ;4((-r$m zQ2DjW%QzYqb^fZWFJOM~UVL9%B41uuV2cz1GV7 z$*BE`EPwy0hP$k-wRU0qQSh?f3y(b$k9E4dCX(F6zh6XAAr`jZG80GZ%5b`h+Pomm z(?V~(7?W2&@;ZY-tWe|IF^&0#uK)V-k_p+!PUwFY&->Fy``;Fr3|gC-<~CS5w*vur zO@i3^DcikXQDg|Vgnh0VEU&Ku>{fek<0MoBLle>vUD#B#bTVK*bh&!7i^X}zN;A4a zekt)8X$Z(#?rZqpGWLTZ-d6CN@%_X%q5b#0u^?;e%#8xVqpc0{CFJU$fos2)7E6U# zX!CE6C?k%S$e6A(n7|+oBGTKdqwI2uHlkY)Bc%{6FF2oyw+`g#32K&@hM&2_B;>fY zaEBiSU)#!kYjsqs_Q!W8*0J@sWQ=4kTfq+iiXt}Q_lb};&u@Hnp7Ge2#KpCO$&SN& zOkg>eo^(3H&|k&qW7e6kcQ$5PH^BU*6`n2A9sfs<=69IVZ9Z)QWH18glV6KGYKb_3 zSq*)uHs}#u1sa|unf(3;Aa6#PstZLXSf%crqN|3|O$(}NdKj4t0}tYl#Yoqu7!0%h z1Rt$x--F&9P4_I>UhwvqZ^Uh3hBLMosa~`Y=!8+6ReUlnPC&#R^n7WXyn%S zg#3|mYWvPJ2em?NLl19y_ixJlZEK5yHRp|K+P=wzm3|yPT6i=`UHpqt&71t3pnTHm z`&!Xi`9P}zgGykJI?k!Ztytw%J4ITQNd3trj=Usn5_4$b^`^;_y`Klbd`DpYtWOmT z>b6o(Ym1SzA=Wt+!08lFUlsRbieIe@zAEvu33^+U=k2KPCbV&D%;soUC?V;qj7RYA zNGtsqW$o(C336+n!$qsmfGm~B#TTl1x}#>fa}HlTIMag=S9xd0ANA}P*OzWfojMUi zmoWs6n>{MFOXcQJAO2bK3YAFqt$Z86dWbZT5zCw_*F1JCZbr~Ew-RedzVjg#3@w0x zqpLu==4ULbcDO)yG}R1@CC)xL{z?~&OJ|`>3uL zK&feLTl&S^1-#{d@3^*bc!w}kN#s?=%=*_&B7YqlIf0%k92$VdkH3TyR~ zV`P=vzTcG_;BV}f{c?9j-d;j6Mq_Sv74)cPAv8r=*6xZUvvde2$!0O~Q5r*v04cVC z_lnu2J`z_ME;i^E0kTm>22a zsZ4xG0fHiR41q+*l>Y(nGsqZuG|QSGHPhwNOZH=qKWDyw;bCx~c)c$>R zez~il>8tmP8#I{hMYh6V{?!54Cv zL$|rTA4a2W2-exXU@4>td;#vTa!C}^9gmFpekryopfHa@WEc&^#|?o(8F`mgj>G>< zf%9BPvurj(iwY~K-KVpVA=0(Cy7G(;{vC!D@wz>ZM>fmLGH{miPC=m~z zDu3cg^$|woUQO|W`o#0I{98dtySNbO+!r4mC9h9L0_QKeHdZ6M4rDw9;;<}q@6^C3 zDZk^3@`p%Y)EU{$2jCYfwTK$=5M(oP1#F^oTE#k$L-?SfM&&gQlDo^7o*dyKWPd8~ zAaR4|tq&`3lL-8BSE5izY^!YY#m8iEhUnGXmI;%~{Tt5buQU@ILetgbg&oHrG(vFO zc+A#@O0ue<{f&CF0*2_sZ>tW#z-zJ}FLW0gBZ4spjS|zVU>Mm56awHzuZq||%r;;$ zhP}b&?ckt4JPhO4hIWSgXx*0_r>O^q?&qzf7o%lz(e92_fwK5Rqh~blBa61g1g>O5 z36$&A8)Hr^yj8}nC2LiT4*bz2ar{LYB@*s+ zv$f2tkhvliR={i(ww_$v^f7ZSfI@NT{@qT=B(LgVinfufSi~aJ)Z>xDKiEWf!)Lhh z-n1&$>pWO^*Q-5k+#RsheyhGHcjV2w0E^pO&WQaMYR33e*nXw9i%qy<7v19b&lX*4 zkd;$@XEgx7ZOt^r-QX_qtSgnYmP}31cB10EZi|yox^1B2&}Zjv|J!}j7fW0$ims2U z=~Sj$h|eYmPE|bf+xZwxIs*b#zY~L&nW{IMRToa)x$$V|UYZoFwGJNKo3AO6X3t+W z>vdSgl@xz+fr1>`m7)0*2|Jk3YTH}NV}z`H9t}%8rWA2`nk7E}NW|+5PPtD*NthBj z>RNT^{_A0jXeTZ6ir7jViQpAN&ub^7hAne$-OEbJ%y*O(WzABFop;c>JRWj>uTvxO zZTDs(){4NQt3-)~uYT-ByVXR6w%a0?>y}>mM0WTrwX6#imIpku;}~}~VCYH*hX}C{ z$1hjd$9WoD%wluK#iEO0m`k8z zZN#RvZ$l07TH4gue6?1;Y&3@~7&Pghw_C0MzV9Zl`+WT&oUGIK(-HU7*@v8)|D+k7 zZU^&W{7DtTLSx}}poO@joJ3lw(`F}$lq1)ug^3qO}^Gv37J$tiW`X` zt&QGzNNa(q!)<5sl*3$<3F6Lk!SRgi_g#|;`Jx0;a7F1AolgA#G9*N`2fhA4KT`f% z+Uh!wlrY&-PmVl%^Bn-;q#Bm90GI&wR@&Xqc*WeyBo56`!-Ji5cgLS53%~aBG2ESJ zkSE3L$#%PUUsE3LEoyvJ2YDvzmdnptj3RZY>JB!1k?gp)y~fxru%EpQqK)2a2V11C z1|XHp2FwUls5haDehWmjbcp&AgK;8rABg6c6|o@uU;&2y1_||iHJs*Dg}4mWh#6Nv ze+K;f5r1DQ_h!%un6pgeg3cSt@T)D4eGQ3Pf^PoZ+(1WIfVHbnOimmfB?SaxdJ=(V ztn=y|$Xc#L?;=X-N&%dL5|s#@H|`3Mx%fBffFZIWc){9$N6x%0p?mE$`8YHEc@hvY zDSZenJstrb7+OG>`jPOeA+~Zu4TZUoK5SrspzX<7flCYBYD{1tdV{YP#8}a0<$&|7 zlk$MEez^`EE4tf3SqWqgYy45M_fOq@)GJj{22)T-vD0XlUr19_KWBr&m*#xE^J)pr z1LJPtKr#&7r_GAnyrviPsO#7i=Sh!gwQ@4dJgQNdgvOE58m8#4Sdv+#_h56*#Smje zVzt_Kb>6?WFMz_an^cO=x}<1jzPqI3WHdvOP2O73F&Y};j+0tS&*+#3)88`mC}LyI z+MgXte6nb)B5=S2sI|JPUid~C5V<~C>2GEx9rxgqMRVCH%%~P8VcUKLjgG%vw7TN6 z4VdE>24&O}F$Zw|=`hLtAs#?{?KPgw_8Fq7J2aoR+x@_H=$|%JW(;JMF z&~=o9#vEzI*P;8J^U#I|N_>eQU7PEmx4=L`d8U8ea{g}*E;@*6KEKGi0n20C(LYgN zLJS64i6Zm6$0gePUspm|szG8THF1##4g|Xgu=Y&yOqwqkHMPK$j2TbaDDh?Pli#a* zb9H#}4)Cqlw09vN&wzx2U5fG?emA3qFcxQzms^$FXnEBhrrl`yC9O8ipHu9=_U$W4 zKvDX_uRaSDp6}M}`f2I{&pmZH2)~_&!dIi#xx#ym!=U+0lz>ZR3^a2J%?4QrfO_C# z?N(+)1WeXFD}5QE#0ZvE1ku60spZRFouHI3{f?0M$ze~c=(Eui6Ksv3(PxY^sCS5T zOFQZRcw_;~0_X?_;L&Ubkd}uz8Vx}Si0aF`XYOoQHQ-Ak$^o!=#Kogd^+X)c46{6E z$DwF5C>%(Q9{*ABh{{Uoy!vssb{7l};w$Ox8RtYEQ&K17CtM+?6-9lKaLnVRN4?O* zCd`WY5#Bg+&bt}qB|17OIR+XQ^=Z*NSV=7x+|h8(bI)o-EG?4@X;8DpiZO|qU8!Z8As0+e2qX7K?5?(3@xWR=+ z?b61^#sOO%;uE(~PvM=2rJLM6;QRdE(&U)y0GR-`r-S9NC9%x;xw}ny*9FIy)|bR4 zBc49-ZkQ#>;n-@Scz@NJ3CEmpo(TUdkz1`pvXRyI@Xrh%x%@YBdc+d+iDh6gu2yLi;N|@Ojo0AxPWBTFpM;rnjiZ~=roH2*Q5)OxIoNrM`!y2$%0gz?5 zm;c~XuY%r4xa}%qa^h1~PK)YjMwVC;bRiVfbB|Hb$j^zQE(Bl6bzNBW4c>h1|7cSz zb@R-zU;kUk22Zku15TI84hq2ukhpMKM^UZMBlkIigJvN*vU>4+K_|MIu9)oAN?MCm zs^8>etU3y^sboTGPV*r{L#U*QZ>r9C!gWx7{?4YCD2^%Uj)nzsQxC9z-9Qkd`_oHH z8#yKDM(YVZcXGD)4 z0)F}Vd+;1$j`nwne_f0G9tK4tQFgCAh!2gxTzv2~Y~dlkv&39XTGaw6T+#I)m2Dzk z_x7{j>XDd*BI-Q1gS9{7P8WNZ z7;YSg`7J}0M4}})AJ5{99elPh#CW5xLex-FRyG`*v!WM(8`gqj+3t5uWq}UZTdT;} zKfOTi3E}slQuF((qk?;UphF}ObggEy#U5XQl4qem0L&O#;ov@`G)=NQclg?o%Xi@l z1)fh#T&Wt>+5@nl3#AF1hVL&gIBJ^5MF?<2!GM_ThO-@Y0-EDa@;n^Ac)}V z{Afd_$=U}H@q)*Uc~_P46TM<4sHF_jquK6-Ak~yok|>5^UU{;(e+XD#NZ?&Q{Jii! z&>~PAX9g1`0{PlKcn?=r5PC|Cjc^c;_Q+nQmi_E>Amy7>0}d@~%ZjE{<6tkWm6>ggznfZMUZwgbRQqH1 zwqs$^{c(ZL-8-@rSKll~vujq#R~;88zwVe)Bo@|dC(tFAkE6!LO<%3pdldy*>iVUKy z8?MDSN<5QkY2%5$A=ZBi)v|7-rq`G8>-DUVXJ?07;|D{QGlEP%@Of7Vzhdxr={ok@ zLa{*a=I%j%NqQZ~k;`(NX@1a$r!@*PboNiord6Ad7Sz=j-rUK!Jw*F<(5V3r7O^|p zIjaEAbz(EalHN>Jesyy?jn*g1m7jU&KJalKFwgBLYmV8S))I-&S-(2h<%K~kNc#1) zdGEQ;%8{Ez79`}uZaN3VGlmvbWFu4NAHS^CsR=G9UOU=}KUd_R#p=W4F02_>1P)ma zUk6P9!3%^y3&1B151*iqXc;m94nZgDpMNX#7$ZUAr2XelJcqk20S(t8`)BzG)uPQz z3l7b!S*&u}8Aty!#Ig8NPyY%yX0TD~N!=XJ{)#f_wC*u#MO>lfPzKH}^dtj{Jx{=8 zbz4P5CnhU}>vy0+an~0@;V-;&=86syw~$*0MJR?~k;QCVIHHe7M9)4kiK{R*6_J%! zdkYzJrqQddR~4-4=Kl4%6wwnTd=ezYqb`XaZX+LV_oqZlZCs>` zjD=RdagyJ=`8NHgIlcvH!;v92h~`Y*ZBN-iftc0UkGTO#PzJ3pQ2&P`&3Ncpy#i4X z%Yy}l*0#2g$v4k#j@P;r&DxQP^IZS54B{aPq9o$} zA~5HLC?LQ)(2u!`z3<-3h;^{>iZX)(RAZt?wZFs&V3{7+V%hCwH*@QRrO+bUfaTCzTBkW zbAFce7s!BLMsfAnmgH88VN6*u!6EdqM78GZx69{WU?1bB7joczw8~b|cRNi>(@snW zZ21hKHtq#@hZ;uSbt5b7+k=JE^%}z2y9PAT)qh@G)UOx!-uYF|(Tl^8EPvuH%Fc%~ zi;TTEf;}?)wh)lSw8Mdq7c-l6xaV_UU{c0ZC|VW}jySKfmmL5sTa#hd{wrUH1fe(@ z;$joZdi{Dm=M|B=CE>Pdtjt=cev^uWz;eXBI_AbU!fhuZr-|A6KWv5>?Pohs^NDun zX2Aqdt9pUH7^L_BZcNL!0VLcnryXDVKRl@$(D4U8xucwW{tWw4Ad8tovc%9kk&8!f zjq=Y_6{Q=hoBS1ogB~zJrUtUbqT*zrz8n4ChOzB(gAe@ z+LT51(nv)Sawu)z@x`FaBBEzccP-(|eTmxxb-GkEhk3`Ym1+JGj5jpg$V)K8--t<7}cw`&~j9Q=UF;88LxzW9=klYV1r6@=4X5Zt5yf6eN zLG-@~_5KX>=2%JZj61FID;}Q$ewh&nP~r5wws5*=6ZWB3X=zA8cm<|V2<86%tpD2k zGZG=Cu<18UxtX0~Je1Gay@}aVY9|V0MNd*(vjC= z8TfO{h}6j0AhSm4)r-RU7VT7-fOJ#|uKBBbgl3D6#*VAiTRQkQ$A^5Z zKG0MN7wOYvGrMpQR+clIIUQits2*u<&DQj)Lbb-KDM?b6MD*sa@^hgPh(S}lF-Dd^<7^JQ^oAqa!L-}87c>u9OvxP@9@ zIHR7fDUY62rE0`&2^51^&qN%9*zMM51zmW&es*#B-s932n|Um!BeIn)E&0r~BuY>j zYfmLf46$Ju#F{u>TH;FYG5F{UA2 zTLj&_X0-~Ty9p)PMYM_bb9E20Azd=r51`iV{Z%@OL~BB3e>XR;Fy? z-XR}a0&QX-4=>UlL-eO#&S%Z1vEA>Np^|wK^ya;JdA{SiDQIF4s#XNhwL=au1b@`vD=fY1HtbL~A!D5*FuoZwODxU{qSp#I z)f}nVY@pLI7jzl*DEy)eB`n}UpxiSM#XL-it&-%kVR#`0lk0bjzF*0_1L$k+hzuOdlK-k4#O1m7fTZokXXR_4DM!e5mmC?R0kzLdwj4$V23c$p z9tyy3_+=!Q>n~Xp6YuUAv_8K^`=smi6(G3=(|&-_)>Rokn|i+FlzMVS|x zRdl0Rn2ow4D#BjHj`s5h`d9mB3x8G~D_m99|I|P~O!0iBzTZ~MtYo&sd|I|moWtbi zY0hSO%iL1_61&2?T6KCG0OO(pqx{ZfS*q2&l&|rWeEne@HBVy5{)wIILFwr(Z2T;% zNzbvS`+c9Gyb7lqZq>4X!f)LxM^^BG#;kV0yiWbwLH9uF_`}d9-It}qtY>avMEu~O zq|E&b!4WK&=NEK5{u`-fROc2#qx7weB0x2e?54>?>Bj*iCSmEj_muV#P=Dbdr)Yy$ zSoC*QPp(i63IAf=dJ5-Jy2qntGZ=+~d=N#>1`NHoK=nkp7?~@x55tr*De`lU>@U8} zGP(sr+gLl!<*TM97~b@4&R=BIH5-DX_Wlkd*JJ@E|(YUi09@n?&(Suq2x=w4Imh(J5PeH|c(gl&Jv% zIfs_mhuaUK?|44@HP>7($0*p94@zgN=f!ER_roZk!U1Hk4i`A>PsDmvmd4~}QX8U4 z=CB;2mwm~`tfgHBu`5QpFIQ){ldzHl%$`OpN6MMcM-KQ8t`btd0TMbzD$G9V_huHd zCPyHgCh;7z%i$(%6egdfTS7e0`5m{IPV(reZimJnP-yMvmQS5FSa1aTwy+7m`3>+V zGw>8?%vTn7t&u|KU<$wP9KVyOTfxChsJw9JyP`$5g&Pe$9z=DP;{u_?_rKpY!aemn zAb7|?d3f8=BxoMp%pI`19n6u zkkf|-Gs~9?+eN#hl;!gm>+N8}C@EY`T1pd9$7;qs{Zg>(Nv_T*yn;R{-QP7=&=~;k zo1~OmE!De0r+?jsq#ANTBs~@t%3;}6cw&;s%Rm9w!cg&U8gR4OqITtjd+8RBW@D=K zuZLO9iPtolT5VPgjBHe`Thy{PS^WXS=eR^r_&2t?0$MQKNeOmhBaI zkT@vuI`8CLVpAotXG8v{x>eR(If;yhmi2VK;9cEe94|ETT^O`O+8n#IYT48vv&LC* zV-uMH3pknDyCxGAb~b%`0j$esg>DqX-*h}7Drf2WCswqG6J0PlFx~Lzb}+y@H|QsFIT-uTZzzqg*b(3 z42ltrv;e~I>w#`9r^uYog{ROGA@w;zBFW0Dq!6ClULFih-udx;yABDDFLz)%&a>AE z>{ry#yJEQC$JEVJ85$cS;zvtT#I9I8I;$!ch|z*V7?W3c$&>%$C$P&1#YFv$>K+Tp zL43cm5X?P)g86VCPsx!_VpFYx^}phAb!eB&s@Ls>uNP-3{Pz6Y265kKG5g3w1F@nz^^@M9;fZgoYfg&F1>Z)?fraJYUSpta$A8U zRjau~)%@VH;1A;aG?tp%G~S(B#UCz-q4P@I-y@WZ+tRJ^&)&W@ebr$~)VXavIZV5@ z_u#H{Vm};U`f~S0YbT~tB8OMo89?M$awSUdPP(LuAG{;6z0nmosa9`oGtb?ZKB?J< z&VzX@Te!j|u1P+o_`QUg-RbBxyWzDDjX&n2_wzq)@op<&Uh^!)$DW^3@tv)cmGVw9 zc+2)4s)d|v*2G$Tp8~o2ypzs_i>5s6HY;N_!}hW@Q$BPvFv-snKp@O8MMZ^rX^E-7 ziS5ovM`?T<#|B5?kgMPyw{kbEzdMZHIkP1@p}F56W8am(;d34QrTG@LhAe zSh*;p-ibvfplwmTF232O=R?elV^#XOZ8q{PMF_*DrEBO4bLuIR@B~>d#fnV-a$@c& zosqBYV*_Vh+f|ir9DIIY;?r3^Xj7CiY+k1~0F0ns$CP3o!=q^K`&WqqYxh;X;(xU4 zewa_vW^^>X>?hacGex~LMpwBeuvVz9`$?xim`|xoE|AaawbPfblSElo8~Ne(N<{J= zd*7xW)tN7reUMl-u8J*|ty?k0E-fw!?xm%e4s2d`jn`Q9Y>)pzmsfoS6r!@<30nWg zZ|+R62k$VIAB7)srk}y))agr0sI2h|)W3J@b0YZS<%uceN3kU2xln_@o#1Y(yde~q z$UdiN5LP}!NXT|R03zYjHoAh~(eH3QCs|&w3Jz=GkcgyU*CLQ_0|pECd*m0N{)PKk z5ZuRcn+^XEgSnR@|1vjbT1~s=b*^Snv%PE0ga>9yitVl2Gg&GbPxmWZPJ&&|y^yIG z#}@-Phc6Kv0?$xHp6z-1LQ2#$i)+JEjOJl;L)~o$=vn^ka;s#7tb0Vq#aWz*wBY%h z$%}!sBKGd8Z5O8gVOnj96>4BGRM)W!XMTQxJqY?Ekg+wEmz9l2H`D)5Fq(JtjD(f&fMZk@h$;ON$VWAbxMd;v6A| ziZ2YBY^=}D;QWe*KNA{6$c5nQvxoO_gCfYL4dceLy{?M|xR4$!>HWZE?;5qA_NQ7p z_O{IFkUC8A;vOB!y8M29cy#SZTq_Ki@or0tvS$E#IWI(R^jdX$reQU- z?nxro9j426b;1gcY%Ya>O@}*8vTwMGB zGp$kRoFosw&6|_0JKwi!0y|x{9U^Nsq>-xx+dOnkNgi{;?y5sktN*$=sy3Fn9?@{) zo9lYjIL8;(ar$d6C{3}lerlK7s3fvBPK)eRE)fG;*q`l3nOL3=Gm~eJmhrjA(32Ailv!O*{&JYMl91AyijkKOLh3P+1V3*Qg# zUfsm!q`8u(JnBnJx-MT-2g@wu2fsVQX`9CP%i>hi5~agw z2d73)aq7;pH@+NaxK?)m7Xvs;CBgUm-2pgT2tM(icavzG?@J4-pya&z$xla%{X8kc z;RDl%Tf%ADr<8^XbGK{$)@ER~r?`2+cm`zGmYMrajIKDgCAsm~11OsBQ1v_Y-r2PN zL7*YAK-rC5eNGz1Nv3pEuSt_nHdjJ{=-!Uku+h=^znw)h;y~)bIMk3MvK`0wxV3;eA-% z^fVWEHJ0t3uGg4GQz!ovpycZC(zv^ox-MKi{47AFVjUazx07B>^1Y60XSb%_($Nuy zoS()Op}?5mBMicGnJr)gZ7CRaDpjz_1n8mV5e`MLSMO$sX8vvW*7-RMj3dGi!8Z!G zkuM`ICjL9XKNSjQ>PMU_lxm$F#)CRx82sz%=>5_{wzaVe!G;ZP?;x!4PY+h-8x{YiM}!1Vy-&OvgN zhS9u`+=`4yt6YsB%mH^NQ@;AlC*Gfn#!iR2!d4c=@;ZS9NH`SPAtyR5Hoh&J2;#Zm z9Ns#zp-v;MZ$E6P13z<%$o=}LP#Ge1OYyz;KF8CNMxRNCiw?nYhvFu_xdT=y+Nya; z;;Wwp=&Y?DDm!5Z)O@7<*cMr6zdhI=8>0W=OUlLTn!R;%!vWE5}mF!OGR*IeQd!K)_HBP=#Q5J#G6i=q2Dv?zB~c`rgM;cC^mT?fE(!6FCcCn0-tk z-g!gH)*{x?i}Ek!gV1YA06QTQhb8D^p&)dqsmdOFQbNqORC+d1c8|rj(2pC>P8C69kH5v04}<4@)x5ZRAWYP0`TC=0!Yo%)FS-D;!a<{ zlt~NPgrDB05j`=aoHyg14T_7qzB%80w)^j;EQAJq7M!h?=Th}of4wDg6g1pgEH(9x z5-eW(hhR=n2*fxX%YEvVUY+a70$P30b%cTDHdP@Q+!RKeHSEV6=A~eHr$B-G(`((V z+XW<`e#|}Q$;0_%2NS2L^7fZ?8Npm?n`2nj4Tls!4f9FF5G^D;oU*d;GdOuL(P6@5 zcbD0&0}^2liN*h2K|y)TYc=s$C0Cof@gZofjg4-B<(RPb`2N_oA&iSB&Nq|~R@kN? zzVAvfkPf@pzb{*pRVAnN71*E*p7F8JhetAkEVXJ7Hvor|5sCbi@j>;cRI+$4P#xj= zi;LM5MRZ{{hLW9~z1v3e58}K*@tBlb4rC%+o1cCKc{D(KW%WS3X+gNgojQP7{ue8$ zrr9=ckeWX%=Z8&MKB_JfrrM<5;vP};FJ2e~_Z)p(5bXb? zvk2$Cut`jb!P$$4y^^Wm_v#~`_=Po(gGOcl=YKsXiHgD5uLEjrq}yOC9!P&4$}^*XG<2nA3{-%92jVe9UG^!-Gr}qX*$U< zOEp>B&>p(pVtSJ|NW>Q=>tKdmk$MlrErl-o4;9mJ_p+=E6$}!+pbaTbCFX*y=Y1*e z4$>NJwC=z0XVo~HK z?A}`8HJuj&{X7wWuXSa_=O7j_W!|>?%dNs23q7~Nxox&71 zhR?QLlttfHuW{BHx3dq#d~Ju*f7g31Gjv`UV@@wvgs)a+_A$TS>GH*p$}e9L$p|Q& z6q^Id7%Okdsdtx{4JwVz>E(M0xc2T3NWH3>t^(Li`(~2GDQ)HWu1P< zeV{p9+3INn>}-053G7Y9CEMu_24=34ux=w;r{av}=hi^{S)l`?a6}g>uWRJL9;tE= z#u)i430(sGa^)`oQ-2jn_R#B#>r8y===Kewm4(`H4G0G}z`MEiRirMhr+IeOVrEvy zUGUr6Y`%Yg=UvuibBMr75Q^^sl_;fjjX+aVb8@ib+mQ%w8$@{R5du3MnpTJa55SC9 zZ>hy7E!Z$#4g6E{b|a9Fm+W`b(^IISrhf5s2}}h&vq$=rW)4u8 zQiy9PM<$50LCchM@(nOwlHGCq{!hFS`!|bV)jA@1zV)Zz3j67a z1up`;pwe3!UM#tvGjy| z3|EhRvfZZb5%mEqKi5r!+}l_-n>bX?og+rRTr6;lQ#{HF*i19l z?|;Yy;xzM`=A&1-k`o>NJQ)fiqbX9MqU!7SY8D%X`8yx3jq$VacK=QiD5xesB@P?L zhanLySjeZyV_D@2?-&S@cVRH2Gn0+IH^aai$}j@?e2bd05gXqHAKpF54Jw_Z+@ zezd+F%xoxW<}q2(Z8m@vNW1VYSUN5C{&_YxLxHyfrWec(Lbx6(XWgb-+&A|THI{BDeriw!-$z6-A0<$flaYx{M;+cVrA7sv}qn`L2Lf#$9GmQc>X@5o`^Oh;N+ zx+lpu6qIcgM*m5#9dKW3cHMq!XWk&3`xRi`j1ljQOGeV$s~s%p#T&!;8L4R@t9Lw9 zt|0@Z8pjR(vvU6SXh-`US%Bq`Saood0+@HAc%~nzTHyIAIkGD%7XFjzc_z+37 z)xoKwxZSSzz1M$=kj|a3YmviOkG;&uQIF(NuZ>YmPE!p^m!s{gd;7krwj)UheQI6> z)Lp>-H_Tw!VRNHz%T@7q!#10?wL_(&maDBDU0yXvGXW~fe{wR zap!5qXjd)S2z_$1H(lZ!qzh(?W#9aY&FkRGhUUZObZg;G)q?Q=?>Gb%7Of^*g0Cup z+=RwxzcX0RsEG-txvjo>7}MXHsdzM8XzCF0g(rCQwE1`=hfj!qcImiyeNlLsOV&r5 zM{m2^KRoQR&nQZ7*%^E!xKuj3Ht$?wUNt@v8WoG>M6_x?l)`PYREX8@<#wkH+TuNG z{7Pkec^Z>0?by&aH+WpoMO0@0gO&I`9Z<8<^a8u)TLeUd_QhkVu0$@)Yo)_myrkx3 zG;ip(14!g8pXk&K?LRWLufnS>yk1>+ciV{A*Sq#_UDmnho{(MXi(4>%^>!W$6`n&n zeW16P0pF%z2nMatLYs+wG{)@r0CO2ZLNF79oT19DS>AJ4_8{2PUHHmlQccN(jH9~x z|8Vz~VOefn)Tkn2U;qk834$P!g3_%L(nw1fbayLhA&mmk(v1?*h!{vncXu}eN_yr( zP~P*t-*vv<=f~dL%S}Ab{j4=(%rVAvqZC53hTM?T{J-9Sy7)Fq(+6nrX@6D@#v9X$$8&HiZ8fCl}HV8w|_k2<1rou|t?at&T0SzHqoXN3!& zuxkxJ0$-#`ilg^chQOdnG`mSyP7n=sdpigUKjm2MGB43}#u0)hxQGclKo&Xsz50O) z4Y4?2F6{vG(UGF*%&QfEeCk$#iy7{96);jl#K7;cF41`|v49jNmyOk^s>S z35=~wqcVRLgBVpC)BvnbtHxUT&3OnqiZw9<=`FJIgUPclsh)oE$_nz9dUyYGEW6rC zN4Qw^0f}2Q`x47dn0j6wdv<@ha^i=F%3do;4cR3!^R`jPyO{5hJniWsPB2BZd}rJg z?}-!yTtjDw76A$+zM)a-xbmI72hgbSDa6=^lo5=?k1*0;+S>TZu?VcW0N`17r!TxL7x#_jDd_$*Imz-KV!N*?i^d;-KQy;g zRSXF^;|WHtg0KR2`ugE5h>+pyDDOTuN^XAM>!*T&qELKD?x^#Z)LU$crgG`Q4$kFuNs=LsI4$EKfs zPMG$ww^@8+~SP&g?EoQDk%3m(L6!gedyDmVnPNs-r+H(vXVLws2wEB(y4lRL}>7->?80 z0mmuY>$tca$Zc{IjdU%f8$UeSc;3?N=NP47Ppg_8e}ed0<++99hVhaT64=a?)^Gc4 z!a2SFB20pB3UGn0ZpCEztDr*7uVlH2X7KCZb1R5Tfu6PFc6g>N)Mx4~bKB6w2Ehbe zU5D%1LuCO=s7-of|1j8&!d{uR? zS*S+1*hR;lKL>=Z>a_F1A-NlGOidb`-j=5qSe@`D@*;3xyq>4iy2LPjFZcM`jDKDJ z46DvkZJ-eri7Mmr_qlQPntJAP5&jUe_BBN!@%+*Abd{fdW+N75bpmh~bLO4fOn44B z*C}`BZPfCGc6PqYAM8t%3j6Vh3G(j!W=!qjn-D-QRoYJntVe%*f!{Y?99QjlihW)& zqebpphPw}wZklLU)B)`{Y~Uc7ch3`$IBY#!y~;@v&wC`?*(o#k_0g_(IRQRL_`}yt z1gK29&6;gVRJ(1vtDQ#svW6#n{nM14eoLP;bNZy*hB#CJ(NgC3Ii{)JM_WKVyN=ih zhtqw6!tmu2Y(j$680mu(T#j%OUm}HH9V}BGCR(j`+0NmTkj@{C0w@2c@}P+S>=5b% zY~5m@UYqcK+Xi;V<##dC*s?e7!ER4URghP*DepCUD4+i!Up(WF5gpUMe&WyP4QCVq zhR4&P>|r(Ox~BD}Sn4a}BG2HwA%&x1!fp(x6cIka3k z!)x#U2b4Lhyu;Aw7HhqKL+%7hpPf^VojM^^;Oa_+<_8F5yq#yPiY)u6AGUT6lzT{p z5qBUT1V326?#l!ONv{t*4d8cg&{L4L2|V^|tU#L&0rqnhc)jaJsBXeq?I5627I*H4 ze;-4Dx#@Qz{8Y~_CX8_dA@u;CR>2AvZ8#%MV0uY;I1jwKku$p75nM;<-ou3I_qtN) z5!o!~y$ztMV{5?8_c2Viy84V{8s`3&r)0GjhrVX()t}GNSp|eV5G`JKj!TtIbV$Qt zzm6aY!fdSQlIH%ZI-<{_BdXv{pLFYRs^jA=7&qKk3QhUh`ij;$--ep_0y~C*fv?lO zZeG8iT*dSYlY9l^<<&Hw^fm(9i2E}^YSla+ad{D!mS@Lfy__h4TR7$U`WR&=@=qtls zUp`iv8WhDjSZGM4A6L^k6);-e`My_#FgKu8soOcfqf&H@u<4YxchmXJ<4+CWP1+^< zyY_YLTyWF>b;0A1nr5oCuT1`Px9J+Qi@jXqd~>8pjNxnf@ByGj#1T)bSDnf0FEAB4 zSbRR@_R@QMjslb3DYDnIA*C6y!!s(Yq;yKl42GB{XL%Vviu46L;Rx`~iWKXMIc*I+ z7o7&qW2hB-_Togb6j5G}@syG(i7k6}-fUh_{$OV6^I5%glZ^lY2?{d4B?fQIULD;R z#$6+&3vWLxF&V{S8)UFQJrk7Qc!OS{0dEtVJz72wi@Dn(XU#b)V8tv~x_!p3dn>(v z%t>{!Gq+cyZJuk`Zi_t{KQFr8`Dd@hR^AWZ>^gLDZ{5gMv<^5qxUW$^fH4S<`Lluf zUO6Txi-21W-S4m8yo-b9#7F3PVj2%xq8T+RMLy1O@NYM%K6UFM&zS6~&<>0Su(23@rfHWBc@4 zYg8a0=|ClgP~gwcPKQ9zArhqgH4Ap2W@~*rRk*duH&^T8dwCT(r)*Fg@3mEFXHN{XfZOloga4bK|W~%`%3>qkA|q)7;9#RK6A8k(lzN zvwWf?%OX6%V8kPJ+IkrlW~eke`8FDG@snSt#jh)m!Ki`tsW^um*MzMG+k+3lLNdMf zkL!SA{|7a4b0InKhv$ePjRc8Mq=#j^4KTKmn1N4_G7fvo0YYkG9g{}UNU~af#nhTj zQK0lmbJI$RM92{|;o+j>h7blY+Kp+W{AasqHPedIj> zITC&G8vhY-O1F8hzcN{vsPaL>qaiG~g`2=)@Cu9drGwm$$L6!!89<`*H-_IXd|cj7 zQcBoQGVkZn*&X1=ZAgaV0{!$^HF0tAfP{RP-_FhKf3z%fbRSiI7=>zT$G(5^PXdQ? z4W3ulWU5d&`SLD5MYiYIMQxqP>m*RcItU+1_9dhApwj_F*Ux)(v{Pe?FfI0+#J0+` ztSCXFdu9PWWk)QV-@c!C6~UDOU*-LHEicX&iS6N0Zx)VS{H`)ihkJPTcKRCj9n;*0 zbJbjSD{9*3R|wL|YsWRMYUV47*FNd=UX#Dq5najsDR6Iq2o;MP4NjWoIU{G1g3J%f zvSeUpW6EX9ZxZe_pD3J1$j2IZbho{68xp3t(0y+u9hdp-r|V8=R-GCfNtl$j0SyJE zqpb$iby6R|a84X9`UvEHxW>MuwLQ@9dSVM4Vt&hg&FWex>ESc~ctsc@AH7bcus*?G zYFccd!ZtdaeB>Y3@QaFaO8`(HtFBo9S(hseY&dRn0f92tq+6*vn2|*Mk6eW&8cq}< zNh8-=3yPSEvPv8&8rIYMn=)Z9*R=TC(Ytk4L#@U41bHuwW$$5tx=W{KR#;qua;dKn%4UqDZIT{1`h|2e*%(&bRD z6#y$`fE|6&!Xe9*R^=K1Bm{tKlm}x^p>62KBs}rlazk{7za0AevKL@v5g)yrLbAuN z62rJ$=T?$LN%eA3we%_9CTyE$-vVzg0*%9j=Rw9qQ*>L}$FV3C8C|fN4+w%S!Rwnp zRMN8d2}(>)n80}A6<9Uhf*IRquT#C`7V5@)i(k%we*f*c86feSGndJUl`Y2er8>V=AUETAnjl zoClknS9|kJ!@2aR!d>q_PZ3iyWlxoS)S`U zPedpx9pfE60QaAF90`pJ=kp2jVT#8OW|&Or^7r>SKeah$ieDX+gU1mquPI^94)EJ< zY4Lu_n)LCOX>s&Q_d{lS$MC%F%8L6PFN=K>;+Ke_xNvsM}P{fn@I7 zwDBZWv?Qj2-xrl}lP_1=9EKBbv-dNzXZNX*lHV%pp3~c6@2?Sxy1E$Z&?ix@LP!Eq zm8?a&)k2$%%*^A)Ey2AqhAX-vwfm7 zx!SDxy?OKO=tWP75MXE?y`rZTm&grOWO;p#(rx4(9~)SIq|=~hi$w%#P98?D)x?hby*lIck)H_j}8^`3C%jmZe;x6 zt$ftJy@5Y4N4Pf7Ruh3e*U|jr6E)7r&pWak=UbeodFU%2G55#>&%7_^<2fCLC-dgY zIBBNX8G)#leTc&TKe}PyepFGWkksC7l=eYlJAV`$I(9Ezf>+CMw|)4p|L=D6UOGJv z4eM@jKuzrauc27PywXM;#li-Ed64oQ>;bqe7yCG#^Fy5gnDnSSIRikF0utydNPs2k zKLYlBJI)D6L*?i69FMuV<9sAB7NQM6Ay)7NW>>x$Gg{iz13=o+uVFvYAY{6RNQtOK zs0M&4K{Aw8?g#F%e=i2c9PrbKj2w@ox;5)Pe2|43K{FKTI z{ID^{lcJawOv+_>6Tt|RlarlI`zj=yqco5X@A>kSJ{7!pP^nEAZ>L^&rFMh4l?=Dt zoUyota>ElsQ28Ke5mM64eD3`o3lwi4&xantadpM}sPc^$eKh_DFy;uJg6)?gY!wN} zc&e=F@Rj%HY`6FpF8hRes$$Make zC>^7)PUu?woH+NR*ASMIG zdHxmHyFp?J2SYyG>1kwm@BoeMKX#p~AReB%jW8IRkn)C|aO1?eM`BD$aT^mNznea0 z3rV`;6Yzt4Te=#SYubB2y#Ey?2i#nN<9Om1=9>+-OYo_F zlsR|+PN^rQENCatT<)@_a?7Ly01umkr&(I4k8_XyrFHEy32;4p(A$P!0SDDI{^SSk zus?&8Uk88#nqrcWXokPcKlst*)B{uAa0(+2E%O#&c3_e8WF?KmBfEgp5M+&Xwjgb& zZBYV?3#59iUG2D@97Xr{6v88a1&=)N6lv{KK~nU{PvKlfV6mVn=Iz_J(K41})OO(1 zazo=e)~;%3=lmB!gZAC$#3ByuHl~WIGPt^*|NBP%T@X0z7mvYdq^=Tdl>Co7{HJDy z=Q|_=9MDB_Eytt(@ojte-TxxG>;J#~_BRTP^A8s1?i(7Wl5(000@klw!XrMv`}sT( zyazDN?LtePPMD})Xc583SqOm|q--;VY^Mf!aoWv(+W*{i;6u~o7$&QEP25KgC*K=6 z1)WW~(+UF6V0t}jm!ZdYi!U#g)~qvz20m!VC>u4oW7cgOUgW{xH1*1*^kc4-_74l`k416J|zFheNzw~FJuGD-sooY&8_QrR+8D`D44oP_`KrJ| zn$hrkRB2B_-u0B2(MOvd@zTfYb~$c5YSIVEe({{|@V8;P7XmuJxcUJv3CNdg$Sk6( zUAL-;ZcJ1<`b8T2Q?zq*vy^AVPms^`ITYuaOiRsqx1?ZC*$3M>Gv^C2mCU((d!u@=jjv^=F+mK?)uQ^|MwX}Z=Ny?v3-Y4sK8*~HhDMgfbU zvqYwPs4tJ8!rcavJ~B*`Ex zApFAKk9u@p@;r+Hx#7C2B>rx(f8fl~^ej`*0a+rWRe)A$0f;T!TyHN64mbOt3U9YT9chB;Sq$AR9EFMj7vSc(h0I+%(WRxd4y)&%$o>phN~W-oOB zy?(9Y@8m&KKr|}OjT28@`v7(*OW@iunZXNW{IRL+P*$B6g=Bn=*5IUtYZ%e2=m5gd zAjFgT4+#0d z;e%wORom)E0qb9m07~5un6#IC$ryJDx@Yk*^fKh0G*I&O&jQTMR!7>y81qEdL{gE| zuI*LXa+8DzP0tnNGjOII7{8>WDROWB&$b)0q^!W&W3>Q`r;2RJB)_z%uvZp5*XEJ@=%K*AEXHZuDk(95YagxjJgU&GhA5k~nr{XH`{XW_<91PFyrFQ54n zll%#haSN~(epdpDPaP)zq=Y!O=9d`+w45nvpGQydqkzjy1oq3FPt)k9z*3Oty&<2P z3~~{%ivPdV4Pd}0hsVeIMn2kf_v6e4XO65ncWkCgCAW}g`t;o;uUVBh0j^^?n6(CD+$juU;JmBY2K_aFadub-ET{rrICSIOu>wam=+|jP!2P z)^{BDKDx^C{;S?7DxuV$p4reo0JZ6@n?efpzWk0Ossw+n)kG*RoT3RlI0UH$xy!u@ zq?QQFS@{otH<(G)z=W5=Ft<6Q7;*#Kco#4rWDq}S2qauDGmiejVRD*e`TA3hcg3L* zgB-97xsiK6SEUN1wX6;mFM@aL$G`x=7~*N5Z%W9Wz6-I-r#cmWzTEH`+!Vd<2Amc{ zk0(NH8m~V)MRvV{TWY>B?pX<7v$6D}nohH9jr*hox32lAPiZKfv;pFOAkg_V#b420 z*oU>mDN2z3OWuNr`wi6Z%X`%CHywClT%(U~{zOjvW9kJ!J{?yLBCb1{h1ec=z(WPK67wS=MU~YrL=)@>ytdhjq>ExuNE&7 z;8_yH!-zTiPKwri-toL8rJxPkhfDDqqj^j0^MoX-#hTHdoy1v(_BS!+pLbq@`Npm49)|tdU8|~ zq;@9wo13QDYu>HW-4pKY{U`6__b+u}T_41I*LvSXZCVc|Jbyq6ll*|6blou?f96?; zYWjP&mi#vwxhnb3TBXxXSty!#6uHlV|9c2bep`6cw2t)#f_eVG$M?7K+AT3CV7|=W zfm{RNq+CKae0J7m7CXS^pWzf_+CsiW<#Z=pEcNt95(;uR4@TXD#py&9?^|_u# zAw#LQ*#61!J{`p-TZb^^m!Db(!DH zJ_OPik(I!7o9#vlHjP$qUYak}xH!!BGgPv!Bay3CvK766gKpa-ks0#$M{6!%k@p3T z(0%ad0ZdH3z*~PSLEwvn=m8+~!iiX~>&7*lX#qgvIP2IJj#gZNDN)YX#l0_vM~?H~ zE_5)agD`X%HuaobY#oPfla$Np9N!{Ff51Z2DuQqOTlIISRd?E)+*RxFRT)VaJ;NS@P zofYyMiQIjEM?uO)-VeqSgoR+-r2{ju+qD*0=E#Q*Y-5=abqvJ+)$kA6KJ+pv$}v@T z62409=?NlAAhWWcVzy4314kc|E_R?iBgK~fdA7PI&yez>3FbWuv{o)KqaS_V{2X*D z#7BssbY?qmoJ_VfK11fb#Xq}+89h)hE_t3X!X=&>gleyfi zXpUrB$kzi_9T;od=pe$^h@QYkC}e9!QHMrK5Ve@r4I9Dcj>V^H?+sgJRSgUb61P`; zITUpU*z~zj5E^`cAPBZN7BBlbk7y`SPq*j4{6><_;fzZ!s>O71hVM*{F)Q}3UbFCZ zM&7CsWfTMvsnB=hZayDh67O|N(78B}bjiac%QNE=_r{SFGK?8vhYeA$X= zGadvr&CD)D5_1s_KyV`>l=yy_@Jk?%lct2(yQJ#Dlk7aQQ7>2zmXxnw<<9$^-I>=V zqj803+O)q9yo3h(e}e^`Lz}n`PheqXn%f*85QOcrNuAcylZf9?mm+^d^FGN`?n%I- zSkt;cowOkapNsJtNMkEXu&Y|bYGqvyZDFxqo@g9`a`(EA)joHVQUbq!5uCOS4XPTb zyjIR8S?xw4INZAF8?cXd1SB27KtTLlA9V~c9T%EL5I%O5xYnadu)bcghP`%5b>jOEsu-t_PMIvj0` zV_Z4X$~CZekR|+CJMeWxt<@$VPmjP9m@)giAJ0Z4+VIf)LBBeH`&)yuy%cfGEc#eZ0pup&oPD~02Z+N^b& zUAv~Y&riEOhJm^sF5G&S@nR%FTzc!WC1F~+qJ$unMNe zgsBPFpBx{H;N$zO3uVLD3P{a~;^Iw1S`{A~zOniYtzB(~<7*#+o?noNg(YbA*`o&j z@x-M_;*Z?*dV75I;+OK9kG>hWMl}*Ra9Tm5<`ZB9|Bk^rn-GpF+zPdf%uNsog)VA_ z-`%jL?gK{ZWy4ny(9i}b6ZJBwf$I3jIvf?NvsUmVFUu!-gEfZ|lsZ=NgxA2Pxb^r8 zIz9ex2aNC|7ITiXC+T)of*4MQ9_*oxNBDhm{@IcvpGhVsLCy4FGn1}Pgg`m}@J zeMsiVyd=l}@g^_>;Bh@qBQIKiFv4p--1{(K@-N58zXua9`y@L(J$=0P<6*fc*Og!4 zzkfqNgjNa_g1?J*j5Ob;3OpEU-_L$#<^ky@R#njsB)-X5)nJ}X6bb|SZ@2Ls{-UXZ zS)k6l?0#M~J&7}Nw{01bKr`xxMvT{qmQR!-!%mo<`NU*Gz0Z;_T|d(NrknlK4G_V8 zABJM~b6o8fQ2V~O*1n`Vv|m~;41y=^fb-{xclExT^{Hlq-=?<>{=ZN7XqE}v~#jQW4uuh!wmwI~SUlDH&MZEiYqGgVQTRZ)y zo8TF-*ICjhtuif`0vABINf~#%SA_#SQnixlfy1e8)2bL-M@@MHfVbuO{{EXty@xCx ztqAUdib|>nu@-!6h+7ug^GwvN=C>Qiaj8U5KY+|mkY#IQ1{p|D0s*(FgiQj%(-ivv zo(O0SFN?gr3F?hYf7KhFg4exmfL_43bV9=~_Y!G4yQ8i~E-1R5`De|6Eq@S4pZyJF zEKu}BvA!=2uzp~pZ(E*jNApBuLnik!gb?D^r=Hb4`B`K4NB(#zh%ST-xq(5)H!I5La5 zLeAI7mSiTcI>W~@6(?f&-XrX~%V(6zEvp?eyJd8vNPH+TH!r?34Eot6u+;)cYuD5R z_I{t%G`079U$bWuYJz@fmdzeXuO*Dj>xy;mCO=(zgNM=Js!m{`sn-RVZU*EJnx=Vm zPF-a$LS-sOUo1}9xTHnHI79HNiZ^NOp9J>bz@p$Oj3?RIU|}J(v)LFC_o3>;=exwj zn#|FGDnHh+e}=^ecAD`M3j@tIyS0npPObUi>+!;|FryLBJk+gSP4Je-Diy)}>5{3xiGIFJ|n}`wYrRaWEa$0b>>fr%IGLZ;i-*`_YP=Nucwz%(2*GrohbO z`HsaH5exO6U;)!z@P(C9Ps+{JWuV$~U}eCFj3=Pgo{s10RrzXR3nzZZWqk)&{-IOp ze_`pFz-IZYM*|-Z8i0CWroZeiA{ihd-r^k*h}t*n6+&6GMA+jZ|5=qRhYUgmLL9ma zfwG*<1lLiZfJ&8w2;A{S&XxAKwmQb@^*n0f8P|V|DGgb|Z5*xVYll;|aEG}o4|lpCJ2n>5`l)IrY21i9Qxqp0~mwp{tJ;!3^t_z|omHKJ&BXZUsZV6?mm3J)nDiPM$ z-CL@<2b|7jo@iZa-p*C6iaa;xiI2(}qklt$fRi${n4Vh~%;Jy9wBh0>|#i8N~BFWnQ!-neV<0 zhl`cIy3Qv%_VQPjNBp2E;nLoEZV%EK$@h-uBT0~_o@uf5%@loxh;$2u%~mu$ML3K3E_|F?r(_6{*vGLK8bY4>G#!pqR=JvUyG-}7l> zzKnVO1;OUEHa<{poKqWM>E9R^^ZAKcvV5nCb}hV^q!l!pLKgX_&qNXZ^-(YpWGn zUYP^*?agP_}{;Uh}K?o;q<$P4kK#!eJL z+ue3{$MD1`_v#pU;{rX>7?-a&6F96icVYUS08@@Jh7u@3NA!KC4cNPg?73L?bH@ZJ z&U4@6t$ruubKc;_Y1O^j=#^lPkIj_+V5eR*o2HzW@Fr~Utps(|MaQQR{exgO&nksP zErqF~YA3vw37 zu^?YOqEBCT=kRY-(A&%e$pVk^@4R`)G#5s401~p@dlhp&cXbA-7;USsQMxeuD8Aw@_IH~SP(*q;9 zoB5%sR#<%FH+=5EL#fZ>W|TV=TXq{Mt-A7myR6n>O23%8)1%kH(RmZ|;MMfwbaE7! zLiZ^Q56CSF9XKQPO8?7|2NVmj2l3t|Z(+~-NB`AJ%kOb3IHd9Ti64Fqi* z$bwU`1lt3jwScF2E16LP)1^zI$Za1M@VtUbwa{7r5f~lcn6(iV?J^hu96bwc!6h2m zQQCu(hj^hnu&XLCGBQ&1>7l(tL*bu-xj-e;=dfA%lA&X_dC(?yV_=rGWl)W5v+0a- z7$ZW6Jo4Iroe2TFsdb#gCkPoJT!c*A1u6882$`{R!2f1Mmf;eLA&x8ZqL|>eDtg)j z9p5B|FR`~5Lzkw<3yEL`3!!gFstkEl!^x%EIQ~0Y9|*={z(SYdU>ksWt5gl$Fe`Ms%aqhY5eSiC)8> z=emaP<(y{vaxluAZbi*Ttw2pXsIYOBN_4o!{G4-m>B0KMK=HNbZ%Yd{?r%Nyy(;9; zvx-k+{s7urnUN|%U&p$Z9Cm@cKz(Pivq#d_zrO8+GT3?iee8r^I{8&4FZsbkcO6p*!xyWUMm-lVyd+M?K9X)!WYSmC z>64rWtsinV^~uFH`|P}ld9)2T-RpPC;za;<2rkv8^2k2>{P=cCe`@!qUdpsz-g#%z zv7NAnJKl5Jr_$j2zUb$3x8XmQ=Sks89(we}+S>pe0oLC-T5m6?5b~r|24hZBdY-Ma zp!2vCnyO}&o-mI&{bHoGGU>~&yU!lBd*xH&^P9(d*i;O1Wl@2NTp;ps-}=fVTtj)` z9#JoW&hmYAb~TZZlvoAl<0>E6T5QuwG^rG+04FWzs}vhLlB zz*c5hcirE^?k})uap55bJ?|E~v@fIY`=dC3TYP5JEc({TX5H*|g~-q?CmYJ%kyVIO{GNf=Pe^7&t>|M#II$1wzxA z1)7<^gS~@C<8cpeR$ePoeZA+-dKE)B_+D&*W#d36U^>$e#|-WPq!hs=5Az^u{E`_7 zz-DX?q9u)rPwhj~2^jBO1eU4a*#M*pf`zZ~kVw!L&&msCEL|`Tu0}G6;ctE+fLx5e zjM*m;ee>(gM?Z$`pGWt6nNZ%aCo<&P4dy z2u$^ioIlo#v;cMs#+#lcklUKT;l4q~`m82In|QEHLj-~p*to=#v<|jmjL^6QL7aMS zCMtHtm#oyciVH8_P69rg_M-3;Wtuo|V&d&$Tl2+KMGr8{d*y1XPME9x>3wBt zK-)4WfrFQ^y6ek^4#Q%$jWFRd#eBK#*O>M9bZ@mK$+)q-3>NM)bi#`odd(b)zZuL$ zmLPTc#|*hVi`})@!LNk*t0FyKBV%D7RH-tn@87yOuwh0tFM21t+)=1Ub?HuoH1Q6V%yRvY`U8I-bt~jGuf`C?IM&8SlN|)o6S@`G1 z>JqQJeDl>buAr|8xm5A&`-H)}Mc#vG=QK1Fh4&)JljR-Abbi8yuL21^0OTB)xS*}K zgLO#HW#35Cw-l0an!h3CwU-6!uZaC2`G8xX@Bx8r?Oqwj-F*XM>~56Ue@5GgY>foZ zFbh1rA|aZ%1#b9@lv|=VY=jHr4(1>)Fth;|txWkIEg&$pZUw-Pm9EWf7hALg)D+g! z?UQ2MrBQc}zeFpsjmGj9bvpJDx<;rWScDvL*48{vy4&iB6T4BW=16;3Bds@DQx%(I z$`CTw4_gySObagf^+7f*!Zib|_joICG}>l4U}c=b`24HI zr-5nTVNRq2E{tg3A;(u6czH}IHvx*-N{lgqDzSGPs@>~0Pi6}yGv$b3AJ&~P(Uhd) ze@LZBl?f`#J88@Z2s_C;R9{ zK9zjX1_BLLDW(MqhmdT{GmR(GG>Y@m+ETKdLH20XW17o0hHUeoB{-{5oC&n{ObyTv zCqUyLyX8mlT)oV>)QHgJB=6lzYX7Q@DhT_FR9Dy5lIC|_h<_nw;Q91s*B}~A+I0QE zV*-_mD#fwf;4y%gMAwrDVwwVOjSO_ns}fHnf1b$)h<*C zA36QlHR2i&=ca<^J|GXdVyei^7>yGrVaZ61h!bq2g>wGk&`i?-bx;rJbiBKNAslS) z=n%C;|NmVKR)PFSV_!K~R^dxV4+C*)6Qml4OJ?Qv3B=|SwzjtCiB9dyf4-0qzW9yo zEw9q1ghJ43GE3V|{c7pzWV0wt{c&z5bfziTfanVmSAxMofWkVH)|P-lWq`Ge>Fx)0 z@FxLyM{-ipfG`2%BLp27 zH87F_;lEX|iR27xSVKKlIEkW_cTk81f!1f*T83;#F2CrM2$j268rVt>q^H=cdPu>h zsDU8@)z&G#XNeC2@+-r6j^UP`~_%K^goT$`LUs!7S zXhzhLw}Z*k>2dmXySxq@;70;ba@le5)MhxW&{6e zpyyer%6Fr+;ObUB65nyMpGOuP#YyLhI&qOn-&qSPJI`PF!7>6~#+$)3N}mj!=pwb! zOyf$vmgi!Xr2G1!l=y&r;8&)|J0=!irH3sB8aPmCxL0;eof#6vOUJf7bF}U7u1Kv=<&82IrB8Rb<`j$22v%J0Kzv9zjDn(Z_mwXhV(x>r%aoV(B z-)Nv>Xo8f@SDMGNVP5!M#;EcwiL%x%g=#gMBXN=EE>9N+8m_%LyBwt|dKjzKN8K4B z{aWb7l0;)z1Q37Bxszx91k?1Exfxw>fJM}hhiK`CBMg1j?1>Lnk1_O56R66+o&D4{ zPmp#vrNwOJGQ$#qs*Y(;ejkU|Lpl=~9sk>*2~FhLAB>dNDC#9+qw$B;3*_wvnO)Yr zKnC=i%jU+NefL+8!ux=;_PG_tRqzu8QkoOg zgnKdw3Lg|ge*)c80r-qm;H@u4#jsc4p_#b9A-uJe4~HgpAy+{Pd=#gU9Z5XQc$x~W z^pHg9Aq*>1CN6RR{Y?ijo6xHY>%H=QFHcZhK;M}*Kz*Qr;T^g2ub7%m$KZS%V*|5< zPnBh?6_Or{O=vX|b&Z6Laz}L6IaM@$fpRq@OtWZJT(f#PSkwIp)NF;u`2i+d!!U&t z13UnhORuB=3pMcN%VCU=ZHK&ld1LL^far?>(J>dMKE4%4o1=aFwR z*v{_h6wv%}>Dxv6cj?DzxY=^5&OO$GD{!svgT=Gc8&2hU!FqL+X>E7pzo~sNNY6{U z-IvPK@iOr51pqIES)Xmfs@mFkq=BLutF0s#Imw3)^#GxFobQ_s^V9>}JftnJ8p zJcIF;X;xUY_?A-iCrYyb#W5v^=|?%0NlmtCzfC6dt3^Wf z^-HzCFSXg#(vj!H+&{@&{N1b8e$?YM-tY~YN5sajXAQD8V4rcZ;0{Xxil-Buys&i_ z?Cry`E7sh&`BLI$&U50j>OBtcpN$Yq@s^}AtFalUh;Z+!EC2&h&U zlRDN-9)`N-al12h^&6TgkRAS86ADtSK#Z1mIsI-|Z6J#WC1Gt58bY&JG~GN6u#cUj z$;(!Kz#W!8VKa+eHv2vTWQ+*WT1b_>7BJvjYMFMU`KAwWJ6gyLiH-tjZg%qOAT62+ z6mRiieg9_C(;yAX-x@o<{kn>q3sh-GoiCd0Ej;*v?mpNvCu5`-x=MR7g)5V~+aQjp z4p^O;t{cHRt)-3(qgI1PILP91^c3=P|&%VfBLavtbb^tk7WVF>c z4<4b=Diti5)hcD0-TMiyf<>tP+dMTSN>KMK=g%kQW=-wCd`|;L49B&>Sgu_RFr@_O zXB@JV326PloJmv4u4lA&^5S`K4!b>ytK%9c%_lMO zDg3NCw*nwo(fetDjYql^Gu2hE#m%6tu>kRI>s;>!nqKi1b?w!lI@V_Q@&F(yFC%>Q^6*W9fOxRtgLJLig6+*Lyjq(k$n7iMd{UYIENV% z@fGj)))NgvKpcIq5?*Xu_UkM2&EK?p&k;^(7F>n1<2QyxE4IHk;O#a-5JVu0gZEc3 zru5h1)2*qV>W^&9NxE~7F>GjL6uIdrHq_MIr+?<|<+S)f{<79XCx&&#>@I_}lEV>- zcEMlHtLJwRNv*!t7Ebdmtn0G#+H8zt2`73Vv*M6HkzXk?D;W_a>igv~P8+Xpzx$_n z>AA%0`QULWCx0C$4e9Gg-zgVaABZ_anUc#>&`Pb0Cy;$~Zgw)s%U&S!2oo6(7D4ab z-Z9a+zVj;Y&q$8yrPo@9GwW?Vlu4LX>YZeWuHFb1?`?Amp8DAn+dt)^)~Py@QSYtT z_rOWE^X7(sg}HEpDaAiy_`h)hZV~j^H}S6co_tZFzCBFU=l>W~3F1(hFGWOfm_IGF+$tUPKUVDDBa? zfmln>_J|xbfv5AgqxX#CKKZ`9;tsFosI&H$8|4@NydsK2kXeY+eP0_vWIQ?w zZ%rE{oR+Yjs}dG<%uluAhV$5E@~(Dh#n=jG{pjxBZaa)E^OK-$7KX|#uxS&$#Nx?9 z=F;UrY?^iU_0@;ql5A(PrRCxT_jjY%A#dGl0bo4!D+{2>LzWJnz6>Q${q$6RLI|G0 z{s;ci;?FI_-hT^w5ncx~l?!>mys@NeF0!2^PKB881MLcp8IJ7$=FZND(Q+(QuNsMh zfKPcUMNthZp}O}~Z0Camg5GHTK)oOP}BXQvVGzw~h>D)yg zp5lgS^W7^0=SG_n0x%0SzMeAtk9N|4o8%?`WrKGB<_AGQf67?~0!`DR6s)vOWI%LT z{)>aBz}B)n*VL%oHW(AXz9gH(KAGFwEX^YY`#HzL&*}33aCQm2>UV&%Yn&XQun)1< zf*FuorVq~ENx`9=gd#wAfl6XsUi$a$Xa@}RBs;FeznNX@0Yu#@O~5I5sbs9l3Vd%z%5dj26e2QIb9%Tz(C>QxdQ z#aFmCoDsYr_t?1`D%(PhGM~Hke;1M-6&&C^xGAXy}pKh4!vH?4Fu<0|*yK}Ny zI2efoVqFG>&xlM6l>xERHrWw~eAx~^bQ|M$oT3}=#RS8H7LI!1NXn1LQc?~=!vqJF zCEo?Bf1tCQHH|B|CwXXHJ#*e}G3JgZ=J*b268zzEwI9xDsRUkcHJG(|_?L3CO66o} zf|@_3D7I@2e5z&><3X8qH_CZ^(!UF)Ih8<%eXZwc3;|*UT#9Xl_-fCk37fR9Ok4XD zlY?<^HqcN~z|(;zA0XuXllAV?bkdu{?M5`}$NzqMu7arDYhAjrXt+QX_aJ632tPQg zP4+|Fb5bgT%mHpYELtOjt*H7+ZLskfSj)2%` z#FuBzyw;{DU3#{)3~^@qsgb^#<77T$cS|T$AC7Bfm{ZO1Y7}zxlctGn+y36_Uz>iA zzLu%9E$6}VqDXgR^=mVCmPx5lRN?~--Nx8{%CyM5YPFB5HhQwCGYcz1fK7j;Zm-F=i_jEB3+f3peaO6Vq?$W9{76T-YC4 z-MgN7XOyvCQns@$(r;WxC5w`UKb}P$HmqJ)uf~L0$*cJ4`Wa-7A?dt8_9wz@ik)8Ln z-|=KT8jm7hR5$~6N!M;w03+~Ic;S}AkFhE(KfbY5z$n*w$B&#iH{c6PtA#qD0h;Im z20Yrj89UvI3c?si+kJUj8V{Z803T;jZyn{CC*D6S@c2>2vco8&ZRg+24b2CjVvxun( zldIccsj8gRd5=vc2?>!-A6`{|^3xA%9jeBuwlB@Nk5lCjIU{JX*&UpH=Cd2uXFGQb zH1vuiE~1d*uw-!W^q=(V&xINop4r^yFEWX5R0tVVD8NXY$s^TIRrK8>;8$(m4=|Q+ zm=v~t9PMk4XgCsdS@jZ9sm;i~aE|qhCiwqfP($j|c3sF#u~hTc_Yu*n*$1aZGJ9LB z)c>O=nbhqVw<0hTeeOFXUmEu1`qkRiiAJhyuH80d2gOG)I3NhD{}6;=4)+-qR-Qfp z0oH&|-jjUoWRt-IPddfa0?BtD`T(Vow*C?+eLCTb=o=d&X1;}xyW(@6-nomnqV>1> zf0KzxKKYPlW7(qZ?`6N(_zIud)R0b-@*YAXD^nYl93;MF6y;n@Ay<83JNsH2-*H>v zl6+_pUr08U!64}P(nRM-GRSX z$|^1*AvGD=0-6x})UC>`Rukv={373!_@%3D+69-4*Szy8@60MWcsY9+r<%w-cCKM- zT62&Ych;k`n=UpC%ICc*q{Cyk`rxXj>TVdZm@?g*QZ5SwryVggjk~P)Exsk2k`ev^S ziHO}=dv%A25Ajg#x05|vWccpCQym7TKN?29VS8Xo)T{KNq}JBUwAFaBSXjv>Tr^45 znUedVf|QsBOGlIr=G`t@GY5Tf5|^+nU%K2L^UsV{h8au2-MXzNM6G)FGqZym4U~H6 z%mz&x=cO>-oTY+rA#8rt`q3DP*@^Ut<{RZ(m1Y9R(%X*R5oK0;$Yq94F25K&E3KO4 z{2}o+e^AlN(}X3SPwOR6DxX(L6VT4;P3>+n3wn<}HdAblAQ_!o39~{rm)3t=m@MZm zEeDy2txKcnDu_OZSQFr)aCI3>?Yp0>uc1xE>VH_YJ1(k}#Q>C--rJqXTWT4ALPa5}7z({upmRZ^-*j13fM0DosR9 zEPo3c?3FgQO1cFz0EGr?JNMAdKYU1XqTY}%e%|)7oVxhx$W=K2-E$alJ2C>Q zj4W(hNE}iy_=v&z1e?8YK_jW%madBUmd~cV9*7b$qAoW9ULdeC?>5@}u>R5yG}8;H zafcjKEEwKp2?}7Zu0G(JPxh4-!y1%k>N?rtm3hFlLcrBZluhlBBQhfVVAW-|zx=R< z(=z0IuENiI`=RXaVIDaq-nqcrsfA!DuKNh01uK{^;4_+1JuYkh>wyX$|O-5v9rAcIuY}tE{Y%1A%@3Kcm_Vyef5%>MO zpXaaVdA*+dk6zbxyKlZrWVaeZgT&lF(g7>$5g2&kCiqb39_QbfW!YZ~4?-=$+Lz+2BZ>w%6QwVbk;h!I zbwiWDKyvq(F`!^#A!6S?(%JSPts!jd^nylHfUV+r!a*K-jY|Hq9BBE_8o!;e3mTZw zd=R>t^CR{o;HV^-j?}-kgM&yGE_nYtXvu;QuuR$G>fC=kdXD+^%y0nIcD^bE? z4G?7JkY0Amlsj)m67pHl3T&kV$uY0Jewn503Ee(0L)UIJLH@pvV!iy?`Q3(Dh}!<^ z`+TaHUs0EVbLE@T3nWql0LjfC6l7?F2{A`6 zoZ{NN#v?#>&7xp^6HefMK^t)|m^$}juvjE;)QP`)@*W}|>m@`J+?Fs{W|h#E)}Wlp=GK@jkro@UVk<$ere=6?_E&<-{a# z&hxGe;_HAREdf--XiIAj5SFt__XIG}B*37Spy^@rr+skPU zQMu>zqWaO+Yh{N)=%l9-`Q#PBhU`qNxr>Cw@#|p}%b~9z(^q!UN>3t|#`GNf? zS7hmGE+*?{zmtMh-$=hL+3^=RgLuFIGx1ivS#$4ot<8B%{v~v%iRt6(50%+{=;d3t zO6At?PaNZ+k!`>}Q9Mfoi1>N1$Yqr)sN+A$v*nG5lM7jr?nr!*xal@j2>P)c`FREP z7O4GCixOvJGq_`>Tu59Qm9nfYlQP%Dc;MMU`-)+|lrQ2rKc>L^xWqs3nj=h9;DeIg z#sU*43CMApb?Qt>3Q#&dvc-UC$jI*r3Ifl|95i_}g&M5`Euiz#|F+e@7u3lO<`v(Y zi?D^q;ig1%TLYg8Fupxhgba%3Y<2b{VPY^oRBv4dCPwS_t|T)8_x`Fs#63eDLcW3A z13@lvXvMdp5|BDW&djy;QL)*%pXnu!qG2E^1VEB-o!satfPf5$RNaO>Sjac zoql@TBGVUi`%V@z4g{@C| z%C~4;L&Tm#$&)c{^o)vVTq-CE8k{rWY5S=6_4?jS@e)5< zSS3c-Po8RpPELg7^N)PXIrUPF^4i@j0Km8nDt{wP9Mlb7FpCUhW3A5q>*k=U{0}vN z;f;+{bu-2EAl)A_HAk(0U&wu+>2Je&RjBtViQ_#-SH6kX$1hYE>pyYoKeB%HB+}!k$A%`k#_^60o4pwlIMq5L&!Fr-cx_@U?h-L;M^%VB_GKZhRmh3Jo__21KyjGnEvN9$mGlRh z<2^OJPpw`G-N@pU&ngy{m0c%tYcZg|l}1DA%e9&OR$@ynym}S zgIg`TqGp-J=&w9vNant4sXm!kO30-*FOYEXs&FWiOdrl>y1VX1A<0-)CeJj%lx&Y| zD*`U!dt@!ExTWAvrw!ER+ET&eEqFfcV_16eoqe}y=pX_ER+jdnz}!L*1`qpYDwDm> z@1GF@0YM$Sa=_^P4ybC6CrB1sl3T~KCIVgL_`O!0d(99Y>~-J4WKZM8M}{;P_*I_q zyKW;l9!%$|HNx%(HE%j$FDrr@)(2^L-2s2CgeK}Z!L?v8#JuZ_hO~6#PiIh=1Y{^8C36ne&yfSS1=S<&KCm;CWeUOwr#L*`K$i|l z6Unfu?xtjk!e9cln=Td5;2ih*wU+`_t`4nP*))%>I(5Nf zN^1lNbuOQGh)O?D(IUsDVTl?aHLNd#?0Gj(a|TJ4i7-xhYlgJYsiCq3CMyXf&_N0d znAz`oKP&@G!4N#LS>=!$Lw^3mI&wJ2#U~l(QoKF+q9c=`Xx(@P0Oe^62lH;D_ljm6 z8VKCE4{h@xRMwAYc12wXm}P;YI*?7u%M`A>8w*V&Ms}qn7uFj7sFIbxhn9($lnF=kKn8kO zU`!OY_tmkQoaseMR8SSn!np8ZJOCK2!QtUtcgyqTU9bc1B2kTPo^}H{+BJ5(8s({m zos4;G2atLrY?EYc3b|gOe(~dly^d!d5KGJT&0hjhOV6Dbb~?u6_7Rv9ET3Ayv;i`y z=jCi956pnG{PI#RhRL2YE*@m>*END$RDinf>ORUr@@tqdC@PL(K7$`?9bn7BaRKLZ zmsoY1pRKKmXZeml7C^bsu%_!|%f$WL%ELe{7}*PLMHttQ<(~hp2r_)-n!SfWoeP>p z-1Ty2JDdY>b>Xf4y_ksgi$%H6LcbwOk|N$|-J&<&QVi`Zn;Qf5T<#vFy(7SEcnueo z-{i?4m~(R$R@Wz%yR?}$Z|;F_WKz%<_X}u_r;8WuoQ6F15$dm7VuJwHSulN*^(eaN zOjIu5eyx6?_8H}2P3Ejb@5K6gQl7jhaam_v`{VP;%4Ve84 z#|QKc?I+f(Qdcdab*Gif45PDF);4qQTm7M+aD297y}vAv>A{2}+jbs%_Mx_cC7wKgZiw}r(=7s z{RH#XvKy_GbmTSCQ@m+0Qy4FHAOxxaR82}rfsM0y*GHSlU4I`Be;O16)^_j$7CnRsyQ3r}1`$!h z1U!U@nTiq%qDA(p9^~iVmOJjy(?JIhVSO0@f?C210wOuG|AV6b{6eXy)^{ajYlW1( zv7Q&^i0$mHNN9D?rl%tRe^ef1OD@GFp#%02iSiGd_0SCMxT%AfsV)Zj zG=3DYlATn^2^AtHvN%He+QM;=CJ=}!L*II$r(ipWFQ;8e-meXEGBoy>!JY; zCOmqz#7Z$grn~3_KspNhtR0$CBb zB^kk9LYbSMJ`ruf3rg9~vd6PLvf$Z!Gzm<}l;1e~`TJB{p0IDL*yIt?!AH&S)rCb; z7@EURbb%LT{hh*Vec6>zQjAVDr`J)l1OVX~H6CrfQWD09Iq`})`-TvKd}cc>X@1*c z9k!E!cc!h_;TG`}hRZhx8{PTH*YCeFEt$$^ZP)s9I58xdfLB5XH{tkVbws*XOGqkL zI&?Qv|7UlDpO7xWXrIPm^HAV08?2FiHl1H$r=#$y@P6l$`0ytU`iBiN;f-YWk%!b~ zDF`^of6++Oqs@tyG|%mulP;J52xRfz-=n5q<*S(t(^K< zxws*EcKRUw?0I7Oqi0y{O%^E|y{hZ(3pVD=K zSmStqdi>?3$MkdY?tf38qW^XwGjLqViB|eHEpz6yU#Fps9&^rgthN5HV4+pN*t`@i zh6WzWCKcv6BB$84&#|~Uj(*QcJR4OTFH5Zj$8NNiF?g;8SGJv;ij4^I%e5S4nk#IY z2%~YUn0^$s$St*&X!wwNsbG{>Nlf?6njpdNvpDu)*2DMjK4MuqL%(w}|0`~PPsYZH5QA+auHN?)GM3P?=dBwZL3+3a;ZJDk;xPL%n-A$ zY4ugV69L=S*vmo~;@k^fC(d2tgP0q%h0GHIwlXfagdg80HxPx#E^6%KH@%~nWW_A) z2kK9wN`3lGF_KY&VLT7trr(j~X!Buvhn@XSP^8NFpb;gqXc{xVhp(cpbg$1AE&Vui z)*e+hw5tOqe@|%DWS|&HOUE_rQ%kx1fIhovdH4wjCI%jv9O3B;s^}Y3rAjgE!%21h{J+lg;L*#_0R_X2`_kQ% z>kxW9JP{SpqBKHWh&*^4=6`ocq41K5%1WUW0)~Bh8i_{qni-rzc4ViW;e)DlyCH$F z1etx3QeGuWxq_^^LOM9@;=m!Mq9C^KMfr;1q#W~9Ic{G~&6mQEKE0jJ;K5f+deONWxf9ZezHzgdwdv56edbIH|z*kuDvG~M) z7loZ)b~<5z=>O*T@?F2gmy-1V#~0=!Ivv4o4?RUkfmJr7q+KiEe6=%{*sJoXKrbi> zdOI-CNnKCD!Y6Z6gh#YusH$cA>XSV`O4<3IQqFg{&IN0H@%K$K0BiCoWa2P=>6LlD zBnWAeaI5AO|+3nWV zIGW-&#lH8ox(=ClUbcYbh=H@G-}9}Wmf)T3GCWo8Ltmnwwd}6pyoH}?%PJ4ZhfvpQ zi)N$YhH*;e5xee2g~k_6i3Q@HZ6_$yyt6~o(V?v;kfn_sbIPxLO@I9ZU1x1LJ?}x1 z2hsMqeH#qyTNUS?8XK$Xor{zYzxRN4Aci(iOW;CcI7?4DjgW@33Qi-jQKflekp`vP080;n0j+XU;3jc`bj)dyTHmzM)~9mt_JYSL z1oTTZgcbY9C&IJ!=vV2B-p15K3M+Tz@HGR(kcb4F36_d;hZl>Xyc z-D?!#2{uA?l;)cMj=9!ydJBi}*yNWU zoJ%@<9m(|}$B|C)Oa0*$w6N)YJ}W6A5+HWfY4rvI7Ez;a$n8ci)(7%8MpB!3ca*rq z^yKTC7$)hRVGmt{Gs^G;lrHtE5dEznp_He~XE7u#DI}>rKfFHw1ZHH!Z{`4eQDAyc z{0AMRV2|G`u>rOZwq2hfEN-gO$)Y2%M+zd;U_;W=51j^QQqXR=QwjT&L;me28tv%t z84i#ltL4awK=Peej%0&eM0g^Ym?MtzcvNo+ z4DC;(jLN=64^)eOW{-z@mG!*Q#vPFS_Rm! z&K_nnd3+HNd#GBMdp?LGYmG>C#%eb)QLqRV%LA<>3DGr`mlGWb zXvtc2#hOBbxK~EDp<)wPEw&EY0&QyY5QWrcFKj^uw?`Mrbyfb$647VR2R1#hJaX!y z7i5^N)$A8yL>vSdh~?5!y3O5!N{6Igl4$OmzAQ9mD$}Gi^e$uArOYAn<>~bhiZRSWDfM`j~SN%3~;U=sx zJ?R(rPoQ7)d_>Hgqj#|(NpNR}D(bnU;dV}43-Viat~dO9u<;2}xp}Dsxm?bJK+B=6 z&FQV~qs(s~f>0ItB_jabNdsc8!D|U^o>VGc0Rq<>txjcmJ`UztU#3>HuN@}Sd$rt4 ziYWbSAnjXSle@R~SzpD3$>WbkJH@Zz`bwVvS6J*-Q8%q^-F{lq}cFtLd#&a-Q&M8fbZ*T(=CXizKWFTgCJk`MB9#SLE9({ z2WPlsW5fv%9H9)PTRb2PMUf!ZK0t?9Wf@X-(5>(NO`j;RD5h7M3#qo!x3{7Rtz z4UcQhT*!be&+L6GpjEB}@ni(hz>nkhFanA)q}v=2F#ks;Tf9J*Vx5!NR^z({a= z^5lgZWRNv(QkLCgrV#Vk9C1;5>ajqvgRMH*=iBqfq5^I&?0a<)gs09<4G09Qmt6c? zk;$5lQc{)hbMOU%C>oakaU-TYJ1dyl*MY91qXX^y%^imrgGzA9(x1<0ip_&l;Q2%+ z!0N(C?K2LyR=@}zN238OH#CU&tXU}2ffmw1rS3j&3lI!p%dNfZe)>KsJV9g6lT*(8 zoqyb1Fh8)!EdfyOpt;NegqO4|K&j4;LY#$^)WoWwGv{whaFPYp=b2-zV=G_;sf5Tv zeu2sY5LtwUkNd+urB_CvQ}@v(%3;Y%z#2z53lfG#LU-L5Cvz1Ac|*J)h5w^HZ1Mmm zIW+100{6J^F6=^+boL2qj|I)9H~+ZWeKGw7@)M~TTAaWK{S={67Dbzn2F;7AK8jyU z$Yhte_Uxv+I7Nze2RBXhkt4YdzQq*J{@AUj6k`?khSCh8seQSV_?uI(>ziYoE$%a5fA4eMlI8 z+#3*}o*s@NjkT;dC}3Tlv#p>inY^kJJNk)HdqX!~?z*PT%J~;LgCdc7@}6I$q4i-5 zWN#{ymK=NksIWbwNmjWQqc)E7?QEG)M`3fKt25%5vb$&vZPe*q17)(oD3B-93LJX_ zv1iv79qI}`FE8t%R}4a$Pwm-iD;4S0a5IAo-w zeI#Gu#>h)STJCY`cu5&qMjVFq)>?4C>1@P)<&dmUEilRu7qaURrecpbn0Q87#VxKS zKOwFb>&0#>=EcVIJW}W|sh?Jc{qksOaAmOlOO}`#QYc429*<-MU`RkNl@lY=moJZ^!8_ z<6i9uCeMj?Fw&k_s=`BZeXNffDoAvfsc9o_&J zS2EvP^ez+E8%7pK0U)ntU9ggh)&4Hq0i;1c1!k#tDkzgHUn9NK*Ko7*Ny-LIt>!1R={OH1YFJ)$yMaa>N3a^g_dFh@A0T~PjcaEWrrEO6^6wMryt&;6=Z?R?nNi+}u8<`84s$XPMG23Tw7Opa4CDrJ4y+f^C0Xx&Rtoq*Tkp zHHYHLq+Wh{r6t+cTZDO5$qIU5gUNiz+inyde{THpv&($bgJ&Ujn>p2=uHOOPa7f)$ z9m&mx2#c2AaslGMl0=xO3c=F`+=W2?=jnX`7UHN<{BGCFPlQSTf`kNfd9f?!7UPpD ztUOP>6#J6*+~)SlxYG5X+y&|Uo+aiAAF0C{dIlz;rRuMuAG1LdlbD-0+p6{BvYp86 zrgnIbV_J0!`tK2+#J9yYeDV3Eq-o`GllNBCB44d- z0k1{uW=W-5E#rRR%Y5Si&*1D(39Mem;U)b*-Ji0I2s|iK|YMsh+_xMftEn0Bo`=QaohQDQY4KW#xQ-TX!Mp?iJ;|! zs;oTh{?>Znce&`mfA~ZeQJYZzf}|k&ZuB#LA7m7vEfO8HblG>qAO-gy8~*ij5u4+q zf49D>GucNZ$VZ0{L4b82&9@1(=a+-t1fsT#flnq-7=B-1lm>>Qz;s6c=Mfdx=Na<{ z2M52D+L`un79lAaz}de&ghq2M+)Kq{e$jL~aFJRjG%1Cj)v&3(SrZ-~nVPkD6Z=nZebg#4(ZBns!KYRVVD_i`c?eZ^G zoCVm=VIpM$ifjg6I4|#H$;L@nB&J6{Z3jujw+(u^l;Xow;J#M+8@XEEQzhsnWfYY<92t27LScCPgUBMt0W z@&$%#kdyT3VQgv-rW-75>%AMyXD7&035Fg-S(KUH6z)sRrN<`E@FcLD;7)%g)rHuS zc!T)-1;6W+XckmyAcWh$SrrCID&HG7*UQ5eCl4YIrdj>QV&1}zgJVB+)F4Uijy%M#k{N|+wJ}vi+cOU+PKrF z>}-AVzkYh(5JS)p%c^`WlE!~7)RZGE;mcVPTy}`jzqz$OrxYs2{Pl+vN_lJJvOcc| z!&y&EF7(j+I1bZ+I}C~Nhtp5h3SEpz!}-9+RWEhPt2t5 zplI>yErY~EV2i&bHnFZ`jAneCfsF@*=CD7udqy{#SF83y_CnX-Q@+D{7$b_dCY4Vx zJoRH|qHFZwYb01-1sN?hXdEgJX@bjd>WkghrsenoSyh1aI$$S*@#31L8%uqRNb)?I zYq$ji-Lat6xTlRukOpfIl=CcLxF&!GsT$&QVFlGU@d`-J83e_eFXA{~EdCyA{nLb1 zc{jRe{DmnyDB$;pq-@vtGw{6HVUO}|zf*M<`v}TUm(Qv|lwd2#P2-4igCCG;Hwe;= z`Br~Mgxgy+>*PV1hY-!kiOED*hprFD9;`rxISaa;M=Ue!#)J2>NjK||e+5yPF(TmW zx*KhV7|gt2!oOZkhv)mav(u4LIrwRg$bGge?k<2e zygzfMjf9sX5$mXFO|Et%AQ+!jdQdPGkgwElk88#8P ze^GGSZp&U!*De_*4`~Uc46yXuf%Xp7zk6bP@}`;Qf;pdHcg|;WM72r-uULA?+uc;q zk-f}JA(7%^LXk_UrpCPN59Mr;@Pve1dWr$BcXxYV)FsqyH@n@ER}5`5ooAgBnvR5y zIxUzU*S$_L@yg>40||BvX+?SNOY>T$G*OvH0@-j);3g!5FrmSufTc2>*J z)|r0cdDFfnEiF>c)LmyPl(j+v}%>y-m|gIlcwBN&-#RYB`r#vs3c+@;Y0-(=G--Zrg|R+m_K+kJ`z5fHT~Wk z9fg(I3St+(=Pj+&T2Reqi4Pt5h`Rdm*CSL!eeNx%GRr~UI@Hv$xIAT?svD$nF62R! zvpzNnprA%-Qn7h*+y@Rir=}ebZLIosnmQUscC| zvdtv`upT7zIiJci!QWuvJ-mqy{`RD{`)#T=O0VYOh<06bn=T?<=VVnh++2zIrG=l8 zd@vQjAG8Hb=5b{7r7s8&b;R>-&(WVstan(%I-FH@?ubWgq*Y> z`uZ4O1P~kmzzo~a@f)bM%A>{{$HWhr?n8#Vb3kt7gw?^~))p>ss-m{`WJSt8YV=tk zlHyWM?|z;5*x>sH;$cz)DA7`nL>D(F+12^6fjHA!YDL=DK$xr_R&s+^B6?*0{3))P zmroZ5J;hgRdSs(FbsK($K+~zV-E`KAo(k3@!i^@}r{~85AP@BXZ+JRQcGzHv%k~jR zCQRP=yIe!I@Z+ozH-egD6Casya}-s>ewImSr6AQoAEJdWgQ9Fvs%>OW_z4sE>7UoE zvlX7$9~&_MqukxFmEdLf{7Zjla(fR^vR_Rq2;&-|h-%_D1P~uBryPv3lTCiitsIb$ zK$ipa{ag#--BKi3N898;I-RpH-utKZSrcDRka;hWWb#C#4(OZZV`}>verEw2;N^Yw zVI@d1(Ku_ckq0K3ya+s)LKXtvWr459bPZ<3<1J^i23n(>8dU(AYs~*7#_XcXAoC9^ z2t{P}-Dv?fgYNv|wUTX;7X4824npO?HgUGy6An~`dB#VArYTe4dpYJ8S)>9Jak}vj zBCR*$EN;_oMFVkKGh1dNUX)ewUbY4x#mO=<)idcEIJP@pZA5J#&)F=h*8;UB(jvR zA&4(E95FnuI}QF@^5n8Tp?WzBT$eJfXUm(Pry1|@9I-9vW0>XoXWaoZGJ@vQeF3h% zwcaH(`5NHlf}OO}^cRR&jk>g-RwD!+F6!l6GbIq+Dds^6qcE%?>V9N_6l(8F>R5>N zaq_J{#_r={?;nR3)iDaVwrj>tz^3HdUSRe7J-!i6Jg&0sN@N~CxG#HTJ=L}SAo-x{ z-B$XuU$;A8kz&SfT%XMUjO%mzjEVQ_5Ox1~+4W_2*;nN|JwMYnJnVmV0r|B6&if&{ z`1%x&8<^(rv2Z?3s{$71yK+2(VjBebX+Z4IY{=Gn+2KAE(j*2x2?cQFQbF5cN~C7v zx2t2@iS3(&Hm`f}UvM9abfOju9L2CgJzh-kUXq|2iYXNDwfV15?F7|D?%k2DsRaU=}3Vn zC&a>mmzXdO_>l(FVH2Djv>^=Z)WzSgj4a>eAJR{z7=ql5{?&#ELHx14qN1nM;(i~! z&^C#ES%1dZbSx6>V5BUMG9q2;P{{=%wAq$HK+7mHiX#jw1n zI7}=>=xl&Bh|z`$7I3$PZm2Vg>*`#(!ORw4DRD)CErVFQOgHas5O|6LgeW` zv94X>>@B~uM{Z8UUyL{fv0~ltt+zIQ_Qn2s;fa?pcUBwkEmB+`o|H|Hy967Y{)p#H z5@`kgx|anNQFuN%VS)FnWl}KF5zGll=zGy9MPvbtp=}4GJI|M&PB{LqhW2hstcKu8x{1WaB zqD(b7bc=;dQfhB-Mb9umva>W29=6@bzoxexJ>7IA6+%JGC?<~;VQBq>D?l>h4tyVV zuj%TJ92SbTav??HY?O=_ZLcxm}C$LyPlV_BF7b|CXHUq$<%%)YT zj7YhA7V>tqDK3uSxyn)a6rZr*%N@TdT>&z|PCak&1shNg<{J{o*A3H!=WD13KgeXa zXa5ziCpSf&rUOw@OqT2tyuC}mv28cRD_d~rx|j97oxpC%wK7-f5vymyX&XQCsOTa6 ztgLc*m;z|`p{*Ix7Y`KeTrJgWaxuZebNFw<(XqiA&}Rs4ZH%i-Q(C6z3B(EXdRS9N$8v}XhAr+}c^N77$Lq|w5vtIO93uqd(^ zbjyJagcgbLN2HaZJwW?R(?F8-XYwfakJRus+`4z&M9lqeAGB9X331OFUHDO)R39(T z%>Rgq!}8WcsXzB*{b8w0NhuV8OvnKR#L~wqDp%@?;gDu#PZ+OvZsHQ;j?46V*wm26 zcRSDe=%vKrLSIO@q18XNn@=Ag1*eeb2`Ix{qE&%&&N-eUhVom@ak(dr51) z&Rk9LagGtHAvfH3ED~u=YkJkZ>X6~O82ef=!A9(`LA_mq7d?IKkW6}ETx@eZ)54TUG%$&8VHE~MJATX`|bEUaGlf~9srYW<>=)+aD@~*g544l zqM84{N{F_-z;QnAY9$HIrveD~LAw{WSIvM2{6Sj?l>T7mc(5n0RHL<>3W)O`TUeL; zBmyD~MdqT9jNkTr4}4^da3ADzfe~=_=h45n0w~?iRQQ{ok60UGw7a5^ZE!ZoG#;B| zs&LJq%oq?FgSZ&1YRGBHB_D}r>v{GK$8sJ5_Ub=cOhCYHj4kj921P(EeCDu=c4nu0 zsDUCUh}rbq0H(>ilJw8HFu{(a52(-hVq2mDVcb={GEwu7f7{6JCeb}=kqXr);i zO*w`#J+e<1aVQ{Y8oFyn(lWHHvX|U?dL&-ltP~jg0D)4p13b4jQ1eyGi6m&$Jmg+G z*an{@VLd1Am}ztCxY8Onb^rdZ)9obmD#we29cL(5OE7oUfcE`&k#QTK!$AOXLz6<$ z&L|a1X>w=)!CyEdR<;`yKMK%1kuN6(Iwa0ek29=kE)eCm|?$K268h{Vz3t z^`Q?i8IFs*kI0OnBqS*TF>f8TCoSQl{!66{j_jmfGz{5;<0!ACeg$Uu zJlt2P<(!yap*d-B+{XFB%|?=5vo3wx%&yfYoqt&_fz&CVwxj)YJdW>A9zelIagVwc4=L%D(1<#DEg(hL=l2&q z5Ou1T5glaaZw=9$wG24LrqjJjFpd$V#X?t%q5+-$S9jUk0-Jc)n;)A)gSq$=aC@{k zq-_{}ITJLPgqJS>_NcCSeu?gQ>+gB)Fm1J0wEc8sPWsJ z+|O6XZ;yOX)q>~)hPjE^p{ngH`H3c(1TNLuD0p0@lAHoiLa%3##kjS1ws6pp>2my*M*$=^aEHAJZVy_rsz!q6); z)?#AA??bkj%1_D9w^bg{jB4Cu^G>z{690B4%CRTCP}O#4fP<#({Ymv=M@&O|bs6oE+Tt+Pg(t2Tagr3M=x1A}OQkEL zC9{pQxn_P$o1~x8YY5Su!rCYky1T`j+gmI5VMwNF0(G{KXR@``kw9ycmj_JNQlh(VMzHLFK1l%OQWE2I0#$s~7%q#>7wwi|)?t{l3s%Xgy`sFaN zcK_~}m5KSuouzMA<&K|4F=x9%)wS}bpv3PCWmLs5VU%b>aE9oeI|J|ze=$By%x%dj zgO{3OOT}V%xAlFs?gLx&#kwiB5m7n}OG-XYCReiOH=R#1SglPeG6kyoiasB26s`*F zAI+IglBr@SGp4&3b6H29#Auiud+-rAmVSbySTl|{f@ zD{)KO6dQ>_s=>WZFS}8=$b1 zl=V2Izg~T(c$Dnr0gUZG-io|5Q@@lK8hG;WzaM==b$X$eBZ$wwHa8{du-|FE0|zm; z|5(S=7Fd0f;;{185X!W2R z`%&-2Z+~YninIzaaEB)$|EqFV0bao^n@6)?!rz5K}eR1fTVaxmBB}$B7T!> zH?p(Iysb26>jlo0&=+{g0xnzddR zfK^rWz1*$WdYu`uC{6NGJVG0)z#$(20jLMpw4i@RK^WTKQeiDs)VC>syRI|$FhG76 zIK?iyDL;G%%mi!?bO&y-9Iah&3^*xt@7_Cjk9U8#k}?8rQA_X*>+g%AD~{Xgobu5sUzd`4bp(aK;+2ZBxO$CWz zzDVT?IbGBM0<(2(XH-K7V=OETK;SW?P-rhAK%=p-Q606=ps{T5rLhuyaUtlY?eF%p zk564XOKWk`J4>I=khJV8ZAZ1#W2KwBSpY2gh9NQyttD(ypSr&5Lgr(3e4*${R=BqQ zfPQUH)^E8VWodc^U3!(-IemIntrh{Z`n6qIdNiA#IDtMBCl1U6 z>e(jie1;N1-$i8d#2fyttCid558Fu4k)`U?!QbNfZ>>n6c(QAaTBkbmnp`#s`r328 zj+?voj+>D$L*b2@<`pN#{4KEs-mw^7tO;c~E;_X7IVwB}x(wqX?}w{#yCK$N&d)i= zO`=`WtVTb-7b%34m`-11Lh#0@x56+!HWAZ=(|L2741MuSRN29CkKW|Wtw`P=m*ni6jWiVee=LE7(g~yrI>om3v zpM6uC3!xG{mRnt7bIVUb^>=7IaXPL^8Z?eCC#OE-DGjjR_3z^r6AQUnz*%Oar6gvg zC$L*^`emRrSxnYSDfM-Cy06x9;O&%&Gj#Sb=B+ibs%DVrX|#`6?hK@C`qie|A3E}` z+Bbwa#AZ=msvBDH^G2ytTqFiZ)V~~YV6ppBMn$??XiiUKWKcXWdIf(hJz@O&|Nfyx zr3R^`wVM(Z*pa?H1N`PH1T7a1g?$+S^9@^ic5BRoJD?}@N~AO&4CGX0xbeD{Aiknz z#CF)LuV-rbIM2n2-Qvm@2DOZ9Q(a)j1`5C28`GV7;>%Y}^XUe_eRnx0ks{(#!&#g@_g*(&o_W#%l2l7`CCFV8DTFjAhuX_|kb4TvxRcenz7Y`R))>c_|6E~8 z{a3ii+`d}TNTQ>+H8)c^)0%8hsDu1^zcdaXqk5EC12ptGY+-6NH!lU z8L4`Y4N_X9@uI1EcMxu;@u)8*?_EoK}b*Z zM1FV}(ukOpZS_oOe7=BT!I&;dF~>M6sZ+-n3VmEO9V*;`=P%#jN@f=R8Bfq>POKA-OYsf zg;`J*m3_B^8$BxhTC^k8`qh>KD|2~>0T{iX11p@rImdyXD|pNq)D8E`=TLz4SWE!y zHrI|;F*Z&Em4wJW$QOX%t3_U}s@)su0ER8xtvUj$Jtyh@UU|1`7!Rfs@}T*>ZKkRd zEd%p%jYtYd$(yab-dQPLR#useZ`Zn;q1ao(_5N04Fxz7>@G(r8a__Rvz-9fQbEnbe z#8&Vnn!g|D?JctLAXBi3Tsa|jc=dwW)EU3)->lAxi(^IRQ!-aDk0<_t&GVvp z4Oef%Kq_nL5vi-QMu!-Vhu)-KZQo+In!I_y95tIbyrJPMTIH`Nd+BF4*DjgA%}l-h zHYM$C^9Y)`j$#rOZyX`h!vr@f_ojRwncZ$g3hI(hyeE7*UL$h;#7pR+khpZ)82GwY=z zyX6I(iP-fn&e8FAm3HnPvSiFKo-_jh_QTSnraCjL9*#IWpFnY%)p{`ydmdwH7*u;g z`1Wq`gAVF!D9M=6jN?wfzh|$ z7-c2#0K+Eh$+599#9wXE(fvKJP>sQa&vWK836Y6Z%7?!P4aQOajq7qG*V{lQO^QOT zDPbw`RKe0?6tvQjgp-={e54XzoaQS2MyxJPV9p(X!BE0-Z@dTSHqHyW5(UL)BwYze2O_^#2{fjANcK9Pv4urr$J z5@`45O`FHCCIg9CwB7u+Q?1=E`mFEXvwXLqPQjPucRHJc=I{N(_Y$4WkFxZxz}gC? zY3(MLTtIR^J$MSAWln~l^dkgOXZ6UB1n3Smr`f6Bba`Q@c}6;QbUE!G(Y>_VK?i zHjF25l5}hNz2y-_Knk4=q3o?3|Lm0L%5FX1AmGu3iDS^4zuruwYf9Y-f(1dqv#PB| z=X*^1Y}XRU6q`GAGo=`A6aZNtro1Ug<6Q5%)fy*Lm2J5jYAr}LpZv8%kn=bze8CJB z56exr+#NQFfO8&jmKIs)`_8ocluBHNk>*|ziG0M-$@1HOYQ73?-%@92QF{>tofBF~ z*yI;!jBs%zRpj9(@r#07Ygh?TsK1gq1L%1Rm&c1HQT7V(K8#$=f;2D4nCyF5rUpU+ zdZZ;AYi=S^W7X+d*f3pn$rEPC;YUCvs6(!=0UfWwv!PYON({l)25o-N1b4r{0~j4p z41MZc`t{(E^LYF5?+?zLw~tU0qh-;;|K*6%;=>@<8dP$l2?=m)djt1UiRP;H^&Onw;NZ=&e@!LM7zomXMI3xAv_k zlwyiaA8)K#=kAwF9kZ-@pY)~Li31Nh$d3I^Py8Bb-CEasN~P<8u2GLjWRMk&W(CnN zU2jzAGaIv}vrR7AEd}aTxoG4kTG0*Y_r&ClDD=|STQVldq!k!kJTCsUx0^pn?$^dr zTv-^l6SMu-MwjOIp!*S+Ub1?7td~)-E2~$Hwb9IlEx;^ht`94IkbNsSdVmt@{28rR zdUiZkj4KxEbXD0pM|3CkZP`1%cly{U_RbZRQ9_I)){j|_ve_cV7gM!%Yuq$l!Txf& zRyJ8>SITagD(VSmvd!}r2yLzmr z%y{-124g5perxy_awbv7ez~uc-cbaJ|pTD#G{R`;!M)T}&gJe*BiKyLLS0BAhnY(;-zRc#GWL{{k_!&Ef)oPJ&# zcvG{ZW}vGCWB1W|`4V?~j`lFznHTWHQB6iM`NePWIMybh&ON=?Wdm`Ovk)g&(NTMx zVm2>*3RJgj57*%TVNK6kRemvfAL(?x=;S4#h1D<{qa*@<;(VQ^&j^Syk@ro03iSEK zo6!0DF95k;1!TGP8rw}qKa?Idx;BpDL%VAC-d3?7)Fbi?AnH%iUE1!Q-7aH12*HsR zh=FIQW+Vq9|K`e{LM^GuQZ4;Zp;1eBIJscapHbYF4tzj&pbAPhocz|@yuDni8WSH< zA8>8>Wk-4}AZUIN;8OOMhNUj3rjHt%nj&Es(hR7@iVrgNzX*m+<+9+8pX=cyHUM-o zPon_rM^R?WpmM3$%^@%u!bQhI1kgxW%Sg(kfMG-dO+H!oQr=a*ILYoSy3=L(boTlUInhpi)vsE^x zIe=D+2E$3_=#t+=NE`wd3(Ys7wUiSs&2xtX*Esbg8Ip=HgM2|Vl3>n&xh4iK6AkKm^U3FO>C2Jn>I6MjY=%#ks zj_&|7L``I9)Dga^_E(&dI&-XI}A+I(s#L z>+ymd%acn`=il*Fd=u?-QOp(#eAsGH)OS^S=JDq&ee^m)AC&S%IWu%+^3tlj)1AaC zJEu1eWbuX+_*!!@`3!Ys9Lii|D+yy{ zOA24uypG0(yneB(G|QUtkA7!_UkYR441z!vtF}_tNNMM1Q6d{6so=H@C=)eZi3%_{ zR-Q2Zx=D;{ceq7IP!c?{4`1!x2=^%uR$Zcvj>?5#H!5c4CN7M)oCWs>8^)f6bD{#x z8%0}7Tiv48_b!LxcX8N+275BsH#}ecBT)C5y57)3Mf?bDh_SKHFNag7CwSA}d`hA0 zf9@djl3A$rIj!XP=P}qB8_Z6k`JxH3y==W=47TC=TXYV2*-s58rV^VyAFvHQuMy|? zv4S^QKB~;8@1oZ|S`bJplV4`L6tpDwF+*Rn-T7F?#&xwh3bDEZ(BrQDx^nia?uv@H z{*YP+R(8QNG!fK#XOZ*_k7?e-Rky0puWW6X^EFi1viO*7l%3y`n4F`(a`>13-GX#6 zaTNZ>ru!(T1Vw=&MA8_ z&08>x;RYRKk4^?F%8Hg2(Z9M*b|27@(Jacz@=)TpGCR)=6g*!6wzv-&G8)hc%|g^^X*RTb^O2H+o{nRzD}G?X2SW0szK zfbngBKhDg(f70pxdHPh=-@#Qv?l4ra=?~XbjjoF| zI4E*>EC#KPwLFmx8&;yefLTotlCFnfX^0!BVuj%Jsab^B!*a;+esZsOr440c599Vi zQ(n&6rbq!_@?4+34+{O~6w3ObA>$Su_K};euDm*CKP{#*Hdgfm@O`Hl|hXW?R5H z9)6!DAn;PzSajU=FS=)ZR5^a}MY(l<3{ZDBq`G`f4N zOCmz<&#p7Zw--95ufQ=p4L+W-GBSQbGqv0{(4g-fs5du+VTX8Bws%4D@0fiEKr`r$ zyr+tXqBrkczUN{gGfrh0*9tgE^h z^Y&~~rIh)%dZrCqYzhbP)(b`!_mH7vdbM)I>;sxy5OEwlyf{Fx8LREFIM{LYHSKQb zEN*b?r()=q!5w(fD4;Hk-vsnuE7WKdHj#WK()p@JHxwYn*}Vdl8b9Xx50dQ()Zbt|p%_|_#p$eislB59by2=m5|<3+@%TWGx@vBf%zTW6i^ApJX|!=8DI8`&{2VU;9I%45Tbn_wrX3c{hW34pU*RziC@C1j9UD0Yn4^fi5Yf&ch&{eoGGYX&KApl ztuTfRv$<`@fqqErAQ?lcD!3bUNC73lL!GG>-5JEZ;4*Z4V#~6J@zY412Q}Oll|i zLZgg7)nrxaw2ys$yC)X_+=+_NLOg71W0w{gDmzN8X2&F?1V8~N5kk(>OvVLuBL?6E zY3h*AT|e|J5V!ro#NrQq4>;^Tct`S)!!4H0lhj*+dNgpuE&EA4i_c9co|x-wTW;B! z!xDt(yt6D(ps{XIQ;!HJfQf-WH9)|fQO~MT@}9lQQ&=|nV5G1xQ~qF0-9e7{&=s`Y z(;4hVnYww5qeDO>Q-O-MRYF zlZgLv9mUur+BX)ARuw#Z_Us-p&=X8r;XZ5&qOi9xTE`S1Re-`tG$&o52YdoIx>m?89J) z&?O;4OUUPSzlu{kG}T$^6X5v6B*SNV*_d59Ghy9ypa%1h zmQ(ie(TdbxFUWvhIK!;u#Ygg;DujnAc9I`!n`iBBXciyq$XfSWl?PXKu55|odQ>m` zuCHSCDCy5W%Myu`gwMJo3Ng$b*8ky}v9x1*4b?EK>mcP}T@EkSVTm)6# zuTF1$@St=AjFVq9RV|Js^xo+jj~|7(&%0m1hIH#65=Mxr3XT=8uepK{Hv+tAx#OWP z;?C?0Cf{kX&H@UGu+>7A68Uu2ytHFzMy(X;Mh(@Mc+2Ze`* zOcaG$g6SH6*z2s?9@sKBezRVf-y_c7gfp%YDnQUwrxl2z(#<+DCBP=+rof&O$ecjwr%Zx;zu+9U z=`akh-g8XorT`(dZ4TVr6W}{UkD^XsE41(Nhk)(kMwbnJdw=*#2ZPI=*n=0!;f#&Op4L$hFomG@)k!Okc&8gnsSd{Sn5-)OSbx{Indd4;cc(blC` z&JvR(vzuMn-lY)B!?9{$-XdNo%BQzADHV#TUTtLpU-;Lx>gU8_+RyUcc|8uU@MTf@ znG*qP{@HOSrPritzWcIXrl+vZGm?KRyz>M5uE7tSXk+8R2CpycO&I~Xk{KEkr1en)Z6{a>b`7ikg9rFqoSMU5)JB*JeKHjr}U9=CQzMA@V+|aD0LwjRRM8lV}TqIj6l_+4=DG-iv*MEF~s@ zD46&ink(j`=+3k5zuMr~*7ELJy+n)iq6&M-4z zy*f4^u#K_$#l3Kzgs%hY-`+o!6yz#N+da6|?R?w|x4HWFYP2m*M#5_|HJy)EaZnW$ z0hFT$@c1{bo(<^Z{0uLnB-Ozug9W~0E)pq|VLhuuCrF|n@61D9_vwG+E|^G?rjA`U z-|V-~$?ZPap^4;lDDBYzW>;vd-P^cXSY#0wILiZTp0}7g_8DqT-4f)^7{ZpJ*d0#( z5ML6;$be67G$94 z1A2>5n@*|Jh~PJ<%GMw(sBDFtc)nEj@d>RZGiuOJfL} z1NT}!skD!|I1c~6K4a7qeEGyB%NCPB$mJXGtVuzbLJf#W!ueq1WB^heV$-iTo&ujB z83}1niEU~?2i_^rM%qnub*z6c9)v-zhBi=$2Av)wZ<_)xL

gksqU#xTea5eUYPx;KHXpC{!-x;&6<9~ALQMoOgSeocAH*Ltfr;X z(Q%?SB!vE&aN7q1>H0XE&O2z^;d~mv%@xD*BUQIjGT(bln^yCP=9{m}-F6&uXuvF= zd9RY%Z_Z3rZx3Y1r#ZroQPchT;~K5qUxfnn!4j{3MS<;(t~F6*(xn!n1F~Y+U_l21 ze#K`_mXe(-^6_#JACe>V!TBJSvYDaSY3fYMiPnJI6`#L_!5@~mcdi{_t(5UDqz%Sq z{5&1ZQICJyy|ALdja(|Kfblj74>z|ZHYltS5Y`;$~QIX2O zd)Gs5({jnUi$NLzql3f)9N4t98d)#8B zQ98G`%QeQqphvgo+G=z)w($_A;j_r(Idz~Ln?0)eD~%sQEQN-y+GgR>PkGITp{5)u-=YpR!~9clUd zv>wWY3UlAh{w?|Lh7(c&6F>R4kjkyj^KZ$0$RcJ8A!!4D&#BFVxV0($wi`&Qr(S%a zy(f12&@X)gY0RwR(t*@OsCNpOU!+^kQ>7IP90 zfBMo*5Y(U~yc&QP2Z5fg0;P-za)i?dQ1B9<8E#|$@O!9V!KckXMksa3w0)M4}RAqyLjX1 zlkIMB3rpRcaShN%9g?uvJ_rB#pLqE33~-g2P594m_hNsqANm===m$rTFY>f-1J8|V&Pg_oltmQoN|)1Pn8&?l-0{s&~uKoh|H-0(Pw7JU2V z#cR>Rz~?)6b6gIq8UsmP#4INCHI@D@FsUZNu9Bfq=O4)CF2FSwpNelx;sjDd&Z%s zn`_a{+zVNQD?n9BdP2WNWkPp)A3y?fU3cpvXiKX$zx}^&@-KdR$OlZH5B)Q6{l~X) z;Jap}@PGcv))y^?NAxNy?Vk35WUs$h>)(GEW8bysfBcq?6XFT{f9VB}ds7$hRZ&qf z0Nt<(ps;^Lh)g$c05Lqmc3|y2V^LS&0d3?i(<>(kePEMQ^=rw-4t?2Ppza&vp^LG% z_Lwx0)M*W%pjjHcc!Fc%qRPM%?jl%3`4T%!Rromva;TKMre0;xIw~hGlgT0)){v3v zDim&x*zf+XL+zaYHyN9O*lvf-G>6r@Ii}-j4pT2fD_idn_w~Wa-j{~g%x@FAgTDj&$bus!{4ypb6&J#R_*sq-h z*k_YDJ)5z2r&c<<2Mma@kd>TNk-Q_AG z;XRYp3_ZCMl7TMic~?7O!8iXan1Nolvn9p^AJj;Soxz@e9xOCbh;+gA;4wEP(|i4B z;m@HHMU?`(OetCE@BPh;d$fmA=G>v^@?%{G!W#<<$XL@BoA!JZx5;Ga4b z^c%_){R-Qz|LF!6yF?__jzK7(e5~i1Ga986pe;ml6Np?X(gvBuGz6ydT@kq=T9SnX zoz3q;5tb(-7po29!n3+&!$Tb1pd25yR4TJe%83L{3?RwE=H4SWCcvQ&Kuv^6l; zyb3EG6M;o!ii%8sk^TYpVORQ&IrdaEKu{_on=+*aG!WIkUK8TX>Mcv~6BxcIqUv_FjL)W{i681ipG-jy z-Z)us(d0)S@lP5OIq?GyT!nS@CQaFgYXqp`|KSXqH7*k#h$YI8dGr_3h6mhvqXI$C=lC8l?-OEb)_$%i!!?XK-|aHf~YVgU^t{ z;cI>|;F}C7o|W20Mu{D*)a6Q^xaJg`kW7KgoM{JyP9-EJMoMb)TK1SRJbw$QO>R8! zSuArAi3?k-L|)sUcGdx){i`Zh%#2C4n}LlJH?}Ae7#%l0Yy#F& zkc7#u{HF1M3!ZQFGmrc|s#W2r{+TC8GDrEdX!8P|OWmyu@+Km~Vl43((304J+7dBYL+KA6t z4Ub3x6m*WgYlZ*Qug;j8H1U_K|MpGCmC#$zI|;vN1#dz@UQWbSGmfn&P$==*1SEh4 zfv>p+%Bh3v7Y-_r8ld2_ft8Z)$g8GP%J#Dn#aTF5Xca~j{jB#~%R(MnaJipPK0O6*`C~<&E)Mdu}C($)_D!Ss2Ou51u=B*~)|MI=Hcs?^m??h`W;a>7<+NM_@sf zEo*M^_(Azv%6|2F69Y(%+~em!k;e6nt*jiFhu^! zyp=v-lj3ucXy2-dx7SJY z;Kj_k8f%3UuP0lY4${wlvZV6(6Q>HTdSwu$u*-{8fm>e@DzccFnK>{~ILR{x&dpJl zh3t_2_dPWbtc?KS|A|cs+%=}3*O>vd0)+0AZP84883><) zC>YNgoJ5d`tV}SCL~Y@mosI90XCWapU?Fues^&QhBeIM%Ns19Ezyjzoqgx0}IBE3! z_5h9j-j!ccg+S4a2!13m^ZqNPkdqSJ+eSSYla5Ndin=I0myJ7?d+ltNPSo)HohrPj;n~Ti zePD6$`qsmny<566`Zz9Y86JnbJp^Df->L227{Cn_rKJPuZfp6eKCGHq@a*wDHjT5fA~G4`88;)%)9kZ zda~yGw6OKpzSiHuK{framcW1D!{fQ8l^e;Yocp_<3|!1z)JoCcbaMMtAJPJ5HApovuplUpuidzm=Lmk zpFx>K4M73%_5mSAb5P?j?qJ{uPic*C*FcWJT_U&{^6UHRQ-(1a}EDrt1vF5*$9nsw-r21wRGwBe}zpPMYE7{FxIk z=p=Sgyl9!{xSQ*Fi`c)xywz~Oig=OH432cHK+-MK11~Hur8jumle5)6FCQeC86uM_ zC5?Lv#XE!dE-$_4>^MxGVZOg}S|dBjO!6zP$6Vcc*jy+!w2r>%vbp4P`^uAfw?@X@ zBspeUc}@X=gBSDcWvADANv~z&JQvMIsp=Ui35{xA>wXCI{+28W1(}EMho&p+hSIo|#IC$Ew?>eI={&Q9|JFX#*=? zX<$JoASFSun=|rC9)P39%1Gp3^j8aD%Et%dhGvnL;LMol|Ior>I7J0_?h|p4o_r4& z5Zq*`dXggRp+Uv~y|+#s+C!-4BkjfhfkdW2U}y0YrcB|-ci zs)8U;tCMqz3mjd8f0t+h7-l%sUmw^I@vqzUI1fw1pti~zF$fK|F*cF>k>It)F?`7Q zc=ywOU)KZ7!voMH812)JG1G^|9m_C@U9vDv zehosr7t(Z~FfjH522&miljfwIK%*h-&d^aOo^`q-Kp1wyGZV;|StVpO5%)dVL*UTF zV=9jzpx;-4B-i{l_%ef=#2^Z71);iupzhlvAU7a&r_uvwwmP$q(3;Isu&}Y!Ju3l$ zrt+S2=|I|-p6o0|lw$qSlA*#$#qPq7BTEzIaeXlgO7$bMq#G?_V@%}w7A6vWI42^_O zb?mhEClv>pVY(S|@3viNpvqvmuCZz)Ja;gJmKDmV_4fd~&dS)(8A-R4=H%yUADiFx z^RH>VX%en9rOqn5G39(TcX9Pc{&y0u4A_26Or5bDvmmz=T+V#z&Wg3kwcr^xCm+k1 zXww;e2>7g==Y%2kk);9eDim)t*x#!M`F%AQRj*>C0ad0Xb*+qFu z##p4^$Vo=Yw`7ChG{^N)cccA1hy9{9q88m?f}EKT+EfC4;(7oxg^%9bCrA(!AZ~XXaP>+e3g6h2Va5~z4{?d=D3Fin$28KnQ71V9FhMkkbbzQ2Tok_M% z2-W`;?%8;dc6P7?!c%y0o2JdsEPr3<+so*9`G{lu(;url#0%eZh=7%%0svNDcRl|m zYG_>!V{x-*!xibt z^=2haen9HJMFf0~JZr8rv_Br_@IC(yd;lY5;}=Jscj4feIGiJh@%`MuY z-XS?B5?VNYa(%b8M0BPl1n<>O-#1yGmP!KU8O{Fyr{#U`o^akIiluWX`z|71O>zyf zmhmEz`9jYN{Vv^uhei2xNaZjH6j5T_GdJK<)|utuMJQwt3t`8ZGPC}!@Xs|v;rmij z1i$3bY;iEA=+KKU>GHZ?xWG4(3*32pOLoFBXnrqHm?M(o;>@rtfNh_X!;9O}U~*=d z1ltpL&WS4i1&-In6pW-@IE&){EQx)JUek_+F6=+ZWHs)KNC#EVe2BzV^?i&1= zF?rCZ#5&AApZ=tG8a$i~pyV_{0&A0U68Gn;A;>;Cw0oVi9TcBZ;Ok`w6oa#Z=b)Xl zynXtD38A#|z}C!*_V|t2cDc4wt~a=y&wjFxQMA&a15eoO$9~~ArbK}ehm0NKz%nr8Xr{c)q`p!&NgOS|G9`0~JTapCyw@fJh9$ zo`&ncn^{Kaas2M6HwFkfD{haRGzZ1sWMMo>+wc43!v^+R4oI|oinG{NO~LCvJ!=BU zh~arXF0#-78ByM9ah!|+&OJh8W{L%IzX_gyeR4`)U?Y44bXTc#jEu4%(=q_3!bT9$ zDV^QfJQH8P=P*z$t0*{~uvd2LFb@%HSI8$%vUFYlUr1&!ZBMI#> z_pxAhzR5$`GK^RGT?D@RR)z>yn0~*Laz+*0F1pD^zXPwldN7&r^?53aQ zp#jn%!@s{gM(4_FNnD?KyuM~`D|WHaXV`!w#k(`Mu+BcuE$sfmZ7%>eEKko@_SW}% z0A%yUbzZD<9`vsAJ8`{4Q@D2$pC;(#Zs*2;R2{7AbIX!$fb#YUw-8U2-S6S0kY34< zNs5bcBdC&Mxv}6BQH8PJ<498DrDO;wfFLmaJ?dR{az8omBfyy#@%0- zR;=wCTio@W+4*QoDQ$IyP$kxm(uuv9o;f5H0;e*>&?YEW9^P7N8r-;r%zy5>_DwW=v(TWq(*aaYH)Qzus1%*`|496@5TGk zZB6CO9i-+>n#NkW?;~)NWZ-?^ncn%#oSeL-02^lRTDO7Qrbm1gwCtvFCj4M7>2j`- ztS{$?>9TuYN~!PO!n4d}lKK?e$ioi!hkhqSFi=xf0L%OuP(3q&+C2gp=R2QaUY=1B zARmWK_SmBMfgRG@IX9?+1icprpKP^Y)Q(63#0DvnNY4T^*OKG*J!)@4xqca2^|d&c z(fo6YAjNk^x~TCM$Q^$O$r!S?@r$KXn z`CAwp3$omwowWl$yWD@qdFcP5W_{iD51^ApLLmWQXWhH>^66(Ajo>3-~?PZL0 zx;zgl=HWUH5nU6tSLc8-+|g_el;K1Y1wqo)CMbVzKzsu+DF1MuoUee%+?6*Fy|Je0 zUIWzQ@M6r9ci?x5k|Cs^cn#EWlmLmV09gvTEo$t1H3z_m0g#2?0XH`hBd>D8sx=`~ z%9oDT#>`KFD=|2wJTqif7{lKLcGVHMn@}X4K`W@S22$N`>#M!%5v!5vf>=MkrhKG@ zY%FrN&00m)qp-U}`O|c?w6w396eNE>g~Tt!v*>J3h05uRcK8Wz-uOBU*iUq6LI+yw zu=aNyNyrDtWe60FLNEHwQJ7egdX#nyaW@b**6a9@7UiT9I1r<~OE3Q3F5h`7&m7Cs zbZXe4c|=6`D<}yJGKE37i^vOsYo?7D(Y)6RZ8^3Il7l1tWe|nYD!;Rb^!o1^Tj~J1 z)kY3nq_{N?hkNCbU(j6Wk0wvPwrv7hQtj=>M=84GixoGq#^HAqRuCeAxC$S$u+>i=ZGKZaj+v81GUg_iB=@rg)@WP{xdLhD)JZX-6kL&tho?uq33S(qq zk_OI3tQmGj{|nOjZW$6+1ClZb4-PmZ6wY7--7AcZj3Ni2Vf>SuPasfe4$ZhB9KI-A zPFi06;^b#+k#kmrYCK#pi|f6{$O>Z`)ERDA8?En^Mpxdz93W--mSU$y_kpDj$TIi} z4d1bEwwTI90D%d_M4wHX1q9IF!&@d}B6S&D+uU0kOKu|j8g%p{+GJBgR{kTIiFsHp zMs^1e9kZ9)k-qqXe-MzcXu2n(w-~$p2M3&uyCG)HmN})=0?lskq8!>sUZ)DgPoSv@ zpfGfhd`?~fWJ^v2XZM88JIWJyaA%L*z1roz=KLRe0q9cUZ95g{z1#Zn4FD`|7-C#a zNmCY`h%k<2QrEG64lrV+jwMix(P}|1ANnH4a&0R^yIc)<3y}rxq?HK=UcKK_8h?#E z5dUgmijl8tJ`SsUxcJzA_D^DN1mf=ZeVqmtzfi4FLj%S&`Dz0`W`c~=?{1G=D!YJQ zYX}y&@snxbh#63=o+x+f9q5&v$$H>G&E)*ZUH8akziw{rPP0kqLx83f>xj2#4GNpa zy*H@q1;@Z1-qQWEv-M%D3=#(QSc|A;4S=Ok z$Z+w+U{2$*VSKahRU6i81*M|lKvgh24<#=!d44BYBD?>w%81}>eTsA*Fbrc;ktN)O z(G1uQh|TbW1tV&}EqkF4Gw3V#P1H|^W~msrgAw?G>SB7JC5ghh@&5MXE>B(GtLX?6 zP0fc}Uj_gV4~NYZlGD5aaUN`_-Ua0%O&W@-(RCT2UmVY>N6CX+x66HkoM&ua2Lv_Q zpq3lS-$DJYV|U+oF9E4f>xMxbguX5Mv8ty#ocbCyD$$;@9+cYBO)T=v2$*# zp#j;~{1LVam0m6u*4f|w>8+%trKgVF`nz(M(tz)ckwX|)6-v6T&@=(Q6Ki5+5@+2P zZoOi+e8^{84+5FQ_cLl#^z!oZ&p}o18=%_MV1n?9k%30X&v#DP**r#A_M@W|r+=Tj zLwTL$hmwGGIlO;1e^Ry(D$f;VhlwtxrC5IV$6e3Wb(tWVk2A&^?-Isus4dtA3N+`i z@)msxOpk**e!KwT(Ko3117HLbHZWko?d9!o9DW8r-`(=D39}Dz!+uZySOxDtOnJun ziRZ(nsXNc3#|HV;-u3Y};|s*(@Qg^=aMu6C+$GkJjK+H#+Z0S2d*$GhxO+IuLs70@ z=RNY6w5rYf^_@CW{c?pZv6FKomT?~h_H&e}YX>t_L{ruC*L|KCd{GX>1wpRwDN1>? zF%BhaqwNjM+NXKv&v>SqUaZeJ?bMgk9vk?x$g9%U=Ihb&3n0VDcvm!|m+?_PBef9r zoo!tJIP!2IsDNZFk# z^A<(gvNBe67dn+PpC@EtuAiUmyCXRiW2hkM=PVe4DTrpTW6EpNbd-q?I%OspbD8qO zOP;Tp6S;=iTqdL{pXqD-TCh^Y_gLf1mt2ZF*NoH6zZXrU*34AiTvbL^1%X1-J{<|0 z4#ZF1Hc8%2;d>ovmq2lFr*rbU33t>^OCuJ4Ks+Z4LzMway!fZ=t28Ih;|&uT*1x;F z7YQaE&J@o~h}KuncNBCN3X{F}lJFy2IHOlSU$2$%gd@^Y)In%@=%mS*k?RD~H69}V zVk^oZ&956`gd6*mRfahh z;MPjptC2kz!s@9Jt(oUZud|>LW+^E1 zZRlX2Mr>XWSq8gkv@7>%=iBZ*86ZuZdk}BE((G};f+=9zEMyz+*qq?T!EeexMwotw zxbx3in~)l4FuV5uqb2y~KbI>|#|e%>M?nYd#Q}AY)m_nj4)zeTNeVFuDsN=sClwLb z@d-IEJ60}lt!$glt~uO&{Ga~Kll1TY1$1dhxOhf~HF?WNB>G_aR8-H*)fnG~EE2Br7#fVZx6 zmX`Jk*m}}{LNC1R!G2@^TmKMT%0P{;`8%EH;$~}1O}SyIM#&OOX5U)kV+u7cbg>rq zvKwQj3=KdiShPn3)_g$VUq^mym?a%?w*Fw zMfA0PAv;-ipHfm}UHiH?VJKTlo!JM4;Wuig%~j0W7=apB@FWKUAeZ=K3-SHgFOIaj z7%5;!tj+q^j1~Q>_|mx%g`#4<3k?f76P!Yhz_=y|T4y3Mo>CbkV3)GJ+U1|BaceH> z;yTAk`k}&mSP$7h4hiroyj?dPyCZpW{L|HHT#II&Jyr9&B^c6i^L<+M@DGi3-KU0& zwfn{vbA7R;8a0{-N0a? zD}TD2M6n9I`hN*>P(;kbd05qO(*Bd_;|>5!X^_Q;5O^GOd9JMlO>S5Yzq5sY(aqkG z48E*QkF;5*|4qY4tpxYktNpktv16^9bI&=m-HQlvODF`j;f9f@V)tmmQDoDR0feqM z&Z??{wP~N`JKYvq1av^(gEt5XD%dAZcVUXRbcMZ_rp5PHUY1dXO3ltia{5TBgicP? z;S3e`a(iGe7^(QeS(LBo{t~7$-Ws?)X3T5M%Y-IQ(|mz(i0)#!1J4dDcxj{I`fNpI zdvx=e6Q@D)Py_)^bQeY)6`)Xa?6z;)+nf#ub>eFPnR^bk$5OFC3mS@!KRX)JVep~z zU;=0%n8Zb(f+^WIBaFk5FOQ)?M%Naw0@77;5np5rxNGw}=PoUTABAoOv{pDD3&)AT}a2 zb2gnZlUT%h_@)gU^K`QflQm3UfBF6YSU(F|pA^|yCm=I|4WG3mm;Tdd4VD~;{0_@^ zjG8&Xo19TMX;dqq$4BA@gmBowK`5n=n?p1yQ$@CmFM#>5D>0=B&;|9L5>3QH#bSKld5(U`8 z^I~V|E6b!Mf82Sj5BO7pxG!;WcYXze;0A1G{@nj(kUWnj+{-z_cO<8eY{>wBlA@ol z(EJ+`B0#U}My#SjPL7-|LE`bai|YN;POpyukDEPShxIUm6^&eBRCQZ+o&9y|>H#AI zgBJ^v`GVW);}oHFDiX=iwgzs@ zaFhLQKn-H|1kqEiT#DZkq-sdoiAOrz0BBqu457l(TTyzZB&D-Ba2bt_UJHW)egL|% zi*t=fKZ2IL8Uz>a_RP>dqV$XS7>EB2yMA$c?0!IY)6D7rGHR={J01Y(MQ}dtzP&@2 zC*%XfFa&%Fx7P2lVhiE+AIUIz`4KbQ37{Kdo1~wM`6T0?23D9wPz+E=;?n$JaI_~y9Pl?P-o6p^^J)0S8f8FQl zEejigo9{!o?|x8Y*6nirXiKH}M&qmCq{9=e!FIdhQg+*EqLMCCGxxz{`fJGtc>}#J ztTnm9H5zT*^IC%3<4Rh3Rxdqh1sop5I^g;g8>lrIVO(qIt8{!C3RasyF5)f zvC`1PT=9XznU#NbT{9KKc>iWpF9%Nskt!#%=0?W>w0Vkm@_`bp^;-1;t%KNV*T8@| z=gN4Bjx?P}Jgxl3WXG?MW5IjjMkmp*pz8l@?wdkHKm-(BNn;&x-X~EV^GV^b_f86S zJ(}?+8`ya1uyR7S9^Lv*6Lzg5d#l!;`8;J+OUzfEhKHg}d?U{p@B?g)i|DpFw= z7e(`N<{waf&kWX=?57kmYWmV6B}bb@E6QwpadKgI!9DcAaTX)B?DntN!>O_YZdGa1 zn8RPMSMTd#xXwZWI_Z~78Bf52_&RSCU;pU%zI%*}m6KIMvH~eqSgo(gR*db5Zc!Q( zYLd)I?mF}H#xo9%#@zP{_a};)Y{%n<$Xo(i#D;&KOfhIVy4sR#B8^!a9rrnEGQ-OU zn5)6jVo>4s6V0G*v`s-k47!{8SUX2NMh5cp{iL(8MIC|bAhe?no;haG(XaJZi zGR@Xqv_xcwJ>Pz>L5dAnl5oU1W}f(`^F4!{D|US`+B%G38v^3Urh~_SrLr_H5HfB8 z{LmwcOTba7s=i*<3-{RML4b=PT2_P;(iks6VBDCL zX++uu!)TNB$i(mtghUM-X|yxxm?ZcslGNrbv(nz!IE`!eO##CLm(O0a)2w?n|@xRtmV>{#e0A_X!+_S3oC8 zjTn{S?}vnv8~JL-D5KdP9oe~pG|10!+I^D(Ek<*Qj>1-jCf3Pf*bTNR-@sp%xs$&Q zsN$jah-VCi26*G)CY8}5oheM9i~=nd2aKq}ZMxvryD(v%J0|+Sk|f8$M>Z62!R>xD zoAvqM{YLpKg=l9k)P_835jkPqC0Ee%4|2!0>@3r24evgggd4pYvZnQD?cn{^E z<1D2}W8E#rjRyh@3V4+8YF||0PJ!OoH}8|4`L}>Ma}yZA<`!-^U~?<6K`iP(q2vPf zdM9-S%{&BraZ%W}q^O5O7ul%pjaKb{sgr{C&$V=L(dGeVx0Sf7)4M_8=dwFYJ|JQ-pceKAQyT8}WVn`y)-5JORka??HdE&Xv zNA*yKSaYxDLg}7k4=a*LgaKp`zBTytdLpnJYTz^V3E`fzu0Eh-sjIEMP{9L2bFC=> z{;U6mTn2swC%q0iW@T7%nQ_jIQF*%A0i~d0K4ZrnObo10G~+0$9k(p9d>N6;Bt)hS zzD=gEi=54I-W!)S0f%2WLL@8o!0zglk|oe269GltZ|11pg*%3~6Y4?mi^%vEvrHG&lQy34 zjyC|&G9ceRP*>yE#-AprMbT!pJvTsU58Uf(xy&DFhr5Vu%TE0ta7%arz%2#uPl9{F zUj#wvBmQlXvT{I{Svv1502QE8l_!a8cj5E@ZYN`yn*C+|e+S=PUjM&?@BcdvzRPEl z)kyLebnER`4DN8&XGDd$z|vPTpuHlTbws!);zJwra}LWB?3mD&*lPeqq(_k6D9F|+ zdg?mW7psWH7YJq7q92~bj|D?rQkU7c!}0R3z$k2d>a)JOBlabU!$q>;V*gNKi%$bsuy>pU;<<33AO8@?Wx$O2H9zqAAsYi)-PLnEGMNhLreRTe^Wb@&+aS^-S*7gL%3{^{vt+gt0r1XH*&hKh@}J zi$>`>>?MN%`p~xHX$MSA8cN0hGL``73tttv&>F=bj_KfORtI^;p6TUuxfp>KJE<^x zUBbBk2UfjJ5yyW1-?8e=w=qQCOtt;_h%c6>Z@#?;u@aagaUmuqrv%%x>@l9yXT-ccQP7>Jm$ zb|)YmIxuUm0yt>C!3rkiwHqU-Y zCiaCYzC4Q(gs58%(z%^7LkK2((PuHp`b^6rxr|v4ZLX6vZ8ev9YX-+qLst~Y;EQmHE- zYSa&2{7=-qiH+b)F&Sbhrw?hMyy1cdH%~HZ-zY>?Oc=d?eomPNUpxFO5{TQyQtbyY zD;zZO-_Yh1wjTlK-__W_sznQ3fARcSH`pyrg9U$eRWNOL3g%ml!iE0#GJjYvmso&U zzdKH^l4JP9X6hIE%s=oUNNGi?hhX!qd8yI`{WkpIxzzNsi_!|77JD~jD&cW<_SV?+ zBSgkhICy%cz{B@Vk6~Bd6#6TH=8>GZ=;_&%Svv{F9sL3kyogJRR%l;Y9Q|6IJY$6d}Yq z?iZadn(-^oy39z8-tn&8624-eAJ~g7H3S-fE{grdn;*bIrx#tUR=j-J-YC*Y|I%Qwms;&NeRn#Nj=(EJu%cComNn9G3`s{?C`hBWXBm~}D7mwx&HBL~` zPtLM$yb5Xem;7;1?If}?0i4u<$A%F(Zx=xy74J)(t)O5t`%5_={0Mc#Y~X=K7Dehh z;WeX;x~q_;%#qpE2VbGZDGQ)&4E!j2pqVzDTUz&V=F#~k0-t@)C2ADdr2`fyqj)P* zh&0#RGpc5s>Se+M)cXxqC;O8b3HpSRk7GSCb2^`=(*|x-rkgahxH@GRhiZIuO{tf^ z-QCj|)kUtsT436YWoV)!b@xwu*U-zNHLfMfn2J;Ewv;B%&%O}lG1XT#$2QsWa_zzw z*@yHY5fP_~@Jtf%y=eL+U8ufC_>=CX%yO<#Tv4JJ*w>@L$G~q@n!p}bf68*|iFC+2 z%N+ALTDGO@qK(OxQ&-pumOf|}W95$BJ*%Gk=aqNS*w7H;OoVJA?v=a>!abzb1OEw8xQ$r1eOg0#9q%F3Wy%D-hLRo z?kk{#Vq{!_$Y5yD_VJz|Sb#;i5&)t@^Zv^n^&SiYSJ*_6Ya=AQc|0XO(h1)4U55TG z_xj!qq-J5@gLf4P zt(3V@LZ|?UC1EQodok_-2YM@mK$ajq^6a8TS3CVH6h7w3ucv#ALSSGVtSp2Rg-f)# zzsAb+xG(6aSM=6k(tFc%1RePRC@48MD*(Rnj^k|yb$+iTWu}+k$mYv6` zUK7UFO)no}4PN@cO(&et7NK1~Fuz~KeMww<5<+7IJ;7Gk5RTWlyHTH>_8{|(0=;hV zjy%ac{LkC^K?ik+2=aRC_egG+qO(>v;sukrn;{yQh${>1zw-uifC`kB3UI9rTA{-^ zeI_}SYCqc711-mY-(k4a=fDEU@gM_=TQ3m5K4T#S$~=@T{c-5mngb|h!eIK8fhkNo z0ya1|W4?h?FV&&({N0mOM`n+l;ws1aCUGBVN!+&r9LCYGy^tB*YWR9?Z068L?fE}_ zb!i|q>tnp8luAI#`a3iAP5$9z5u=#0d7xP>B`r?hsFmf2=U%aUgavN7wu&#zkN)g! zIMANf5V}Gv@xx9Va4{-bNc%z;^bSe{Sqz#(Px0RMk;)yXtb$1sGQ zn;axKrr2c#I}KmVM+l&ka>M81Dw-|rAvAKSv3u|>-jM!&vWq^cMhg?Api$F8Uy)Zu z4c_e#ndfq&VDo4N1>?q?&D$5{pM}Y=Vngk@28o{*ZsQm-5(DynY1zrgtm$^#?p@-) zW$t;qg#P_|mXAx9ntGOn0Z$e6TW^`xT^&*6E}}Xi0k$puyicyiLSA64Z2}Rmokzj~ zUQE7;vbeJl<4o9o$G1EdPbyVf3-%v1VDnEZpXTphb^J`A5&yg!y~^u6O)v04qY?@8z373@$= zO2uRLYYcjkF9>N{>lX~h81WDB)&<|QueB|5CMq$xWxVR9z|eWW&y!I|BPj@FC#~Ja zD?tZRVzW(am2_&Med84PJIR{$U+8(p#H5cY=%Xt3_N0u8_` zU#pj{6Z|M-47&X?Y-LKSHc+HD7M&SypFZMoA} zNzXhWn$E6w5)7AlZ+(9Fd-Tj!!Z~^?;!sAP;AHzGuL_@w-uje&gkBHpQ==dT7*XR< zXjQ?@Hh?!!xZ|O-slb*(4#g=DdbuQ&9=2KcXs#E0$4`WKzE_fqqWCH9@f7$^2<-5Iokem z%Hf1#ns|I`rK-Jt|5oGa6)yuvoIuPQ=tqPjIQ1VzJq1yBy5TKY!dwAAPmi!^aZ(hY z$*!ah5mb4H7izid?S=x*@Y4vp&Od{MRS5>gSXS6=*$*Ox5FE(ebwvc0ZLq27BJKhC zVRcTME(?SKA1E0QCv=@9!?^T%FW*7D{^_?YFc}peErTN= zT3F{1Ih#j^=h_Sv!gr(Lv};hK!CWj&0&AcQYe+9rfekz@0+Z2x6l_D8K_%w*vGSb) z7Z=y<>kSVnMn;b4>%QLlU~rQhH$uUnzVozqTRsP})=m<29^GtUL&e5_sZh?Nlp=l% zUz7OJ4I0$i@5jz2bIp0v#!bi;Pj(OlwU53*-#ARH^zcypQ6vsQB57?KKlnpQ9bxO_ zz7gxnli9t{BBQv(|);QwUzmOE)w4+b~csQ-M&@DUY&z>x;&88ul}O@{aqX1UbLTBWM0va!fR@o+M*zy&o*t5S%aTe zkeUvL^_jOT4xJKm`mkXu3S7DRI_)O$BOPqCIqmfd3xB5iLX7&eJkjEJWZi%+zh3lnUGCgC!b88ka>gKg$m;PC14c6(lWWoPbu` z>**t=8{=AbUm}B^U71YzOxGhmewihy9gj9{HM70=SN!40W$mtG1_#aQc3iQNR&Cc{7LApq^)xNq?*6Rij+T9l6jg+MWYSpVn%DO`rDg#T{|+% z=H3f3YPxBd=TrpdxbeO-Pbu!%<0U49&srZnH^U$6Fl68TYR$1Gj4_wpbu;(gUHVUL z0SIBLfF(Qt`Uywwanv22rfNQAE^sXsd?{cz743GUnQ2$Iq$RR^1MQLwOj|le?ZdII z7#w|4jZHn!YAX>>@n=z2#=g<)Zm8b25&E#-t1J8~r)Y@7 z-SK1~&iGHTW`xr2&9mX0DICPcF}qU+R?RNL)`09)lcqg*`aEU1+f$U!D_G4ZTz#HK zAXG7{Y*g>#uk?=s{qhv$zdiwnJ(azqMJ%RR1f-Plu8+%EJWA}HS+{;HjCRz7Q4l|U zo&7&v4UT=v)TUZEXC0nR^;fz!eM+eUqN3)HMHI))Zaa2_x9bO6lTh%88o4_}#_@() zK(Uwx15y91@{k_0XtMP@(uU;+Mh~M8E4CD=|Dymw9^TdLMB-klonF-R+-hk1iC{xZrC0C zz|f)rwj(B2JHe7&T>a;nzbArNDJ{Iz;>jP9SEqj*-hL^lzD6{Vt5(}TR` zx%&j{L5F%2-l1k0f}yT~*AYEhYCf!Y8H0|#FcB{oRX}h72VYU}^gFRS^9J~+D4`l= zaF0Gahsho`#D_q<82YR~*xogU6YnE-yXpV!qjq`GT-_x~ESl*f+-SFd^SuPIph*P8 zNJAIfJ{T;9sHRWv$Fo#)XZjGwY+ zHo3`OxusHxkeNL)vRB;p=$Fjw5w}rBR>lDr^&|Y!1bxlrd)3vPf40s#+hs0~8+9vdt z%wX9(>*b51#e>8Im5V6>8-8MDG{xoQs`?F^ zAuY0o$EDmhhyw|%R(eZ%iD*_!_VyY7fFN84GOgEV73afr>^8qobd%;>NiI^7(tFoM z9fN(Q@Zf_5Vg@xZqNf??6!Um7^&c>`1GF4!kXDLO(aiEPuO06ILfw*~1-IU?aB zy&`D5-!Bedg;F7pyu@I&X6Rq~0~|h25XD@@*WU$(o);Qt0~1MrLWmk7%BtLu!MzyL zv_V0tXA0h~MLy6o+G@^E(B5Cf2n}q^L`#Y3=oIFlDXTP4%MCyu-?e<913gJ@XR+EpvF--^8{aQR*P53^7v4kyWu{5v5m=ALn!D4a zAD%h$F$tQI0j8e)H5fEQfci-jkyu|xUCBCF*1yq0(_1k?W3`h(- z%u-+qeAQI*ql(_KMP9^sKKNjf!uL5_YctI+XoHtEx5rFEm2jkmS za5Zt}1CqGuD}hODE}cLturMtQ0}W_v+iz&8$nm*ca&sw-0TxtGr|20&+)rQ6HN>9g z2j-=eFwbp7dBmsG7I~~uiI)|9-7aN!WY;TM+~V&>la0FbREp+Yn}(%<9+FO+ijZe* zR6pR&3N|#>5_LjKp9UvYC0<5US%ZE01qeI!fu_C!IDasg09B@NLN=v?(#NXD4iyzI zTj~{`GX3ODrU`OHjjmd~@QLR5D{i)LZ*CPh8tLyu&}jC>o>k1@z-m9V_o!1_|Jj8_ z!ebm6l4(H1J)vL+533cX(I`3X3;Uej?yt(~^H;Z?t@{u;XA;;1E#1(i}>< zcyM#0A71|z5pPdk{=8C3@X_UTq)13QbfdQ^82dZIRwzT8a($`+@; zktg)DPN`=kjKkRQv*pU8v6m0mgz-dqU(Wq~+U^(-Oe{K2wBU+oRth>6ZDQlcyUS8y zF;gBZkLpKDJ=_>RZ=OQJ^^e{u=TPW;L|3kkjo=5eN^qRQQt8m|LfE-n#@0!W@+mNI zmdVhSXoJ8E;JRHA=Y%NKAC;fyimNJK%DG2-_kZ0!ML8X(J@Vz%gOFvf;3s^`<;zri ze{Yu5H{4Dy2uBX^LU?q04ns`mA_3sdR`Uj=;!7B4t@lP~35**@ zj@{A!!>*(`B;2~QT?p4W`e%iIx%}A`wom&u+XWw@DDBi_r0fH#MJ24`zJj@~gKr0b z-rZkwecWOI*sqU1s*ocMahUGL+e8i;|F@Xi&qi}1izD4_%bLR_SqpMa2v&ZstbS$7KBP}kKsqa|!?2e_i3 z6dKileFMlI$XX>$Qv@`L}q`3$%b~y@j-L}-6L{% zIR~<=5gDVwU-LPW1u0iu8DoPxH^nRiiO5z$`8iU3#I* zfr&N459Rp^{cI#%R0F#;6Nfv!+U=kB7k^blX0bVpwVV6w=7)-zq>?rE!0XvPw3^(%{8irZK#Cyn|FTcEmii z(xG*gQT#_J%VgXme?2zKve3{0cBco}lxV~5`D}L`O-hc8>{R;Ukt0!g;6M#ZNe=Leh@|D1TvH=Sl9hPqfdj%b-RXlgavT|SyXeI z>hDaMJsatUI(6Uku~@9Uc^BcNOEI+Z(NG6DAZJ2v z;!~rfH{h{V$*Ot69w5hIYE?0Ro0nzK$bewySigsx^%H%|v1a@*}hHrTR&>i`XGYJcZej0UVQ0wgxUVb#`s>Pru7*ad{Agi}B zK#{`f71`f=;F%7;_P*8N&{e|pBO1D@V3!G zY&}~4`Jm8v9wLCMaN~5)k*7lM=pRF>SWck>Oq|r$?czP?U7YuaCEucX1CQ+m zn)R`8>XB>+pNoe0O0&VbbG9&G%jg&eracFMS}gtBn>suI+_2L%opH1yb$SdR_VdQL zuJ}$k_~%W&5?n6n><7JpFFbcufG>0`z_W_tnk#K`cD9R06`K{-N z|FGG~p<5y?40nvY-YHu7$ZPLioFhGiw=L1rV|#-JOqDuc23XqO@V4LHNSgph2;kL{ zkh<>=|60cS_t5rmcjqk@+D$8zn*5Z6=$q5_XWsD-?x?~g{eOBn-TOP8X|V%4ZupKb z_qxHSJMjH!9rgB7*ir4^*@mYR70j)M06l2iv-;zg;M4s-y*&0r>TF%`RpO}F*s4%| z`zW}jriL=7H|wI4J^{*TfXY${Pt6B*)5A{ZMRy6Nmy=vPy?&~8qB~mt=&YB9s@Fb_ zlxoB9F<9T{%iQ-C_J=x@?io45dmN#Fg2s(fF;+@}-*RQUH9HOQktX%c0*vA2F6AnzG>@zF*yD~XwDL}9jCU2sghbMRm#@wyTITA_ixqbH z;lO@|uW;DSkA<=$%TQ_2I1_W4$JIf+D5)cHF|NaEgj1^?H#q!WxqF1%;SoM_g*@#$ zZqRBp@Lcp=MJ0DO!ULqc2Sn|ha;wKW_wX#1ukn*9Sg$S@$4gGQvKuXu3r>SOGbSOc z-<-|YfxvYwNlUk0ird;%N@ECTCa&Afy3&#bT>{Kxle0SGO(=O1Fc`;lI(<(l%5Hp^ z^vf#mVlh8;cRWx#zl5^}JiB)DLwwt%f5$^EbmP2ddPy$CBi=@m&AZkn?T)C6L=K!_ zkBJB-{a~o&+uN3yq5={qi{U0QP+gq|8JjQky;bioYvpHFw2f2p*iLF2!0;4GLRXtU zsO%t9BY}p0|EIm%+bfMf><#9`Zw$9SAo$6^n?v4Y#W$XA7GMK-yRzlp!OrD%SD@_2 zvtXUsVbLgo1~MkS-^)X^^s3*V0~=XC=YtfvG~G{k@99g$JuHRi4Y_<=-oKu8gba_t zKOv(P!sHepXt zeh6FyWCn`tpYz{RNPT4iD~(^eB@SRqfo4Gz#f^A_ES3-_oc$a`;Mt@BI9IJ$e=iZ3 z+um@Zs6=4E(NB$vHV8T2K`fKh%kcLFZR@{n>fq7`D=u%6W^K@ctDgu=JYy-ie@`&m zQ~8nbG>7YakQ=!j>qmF6-AcSWDJ!X0H_c=S=Nlv|=!-XY&}CU9S3PY$0@F7WOkP&^ zdGaN4RlyQrkTymy43e7?h*Wyel?aE&`Vx{krCy_VmC~%2?ImBs$HPCD?d~B{$Wg=_ z*W;#@%Q}EvBgT^3Zk15zoCm8#P*GFjQkt$r*Ssg=hX)ns-_hk!X*9NBeme1D&{#ei zoV1cgU0#`*9wq4m`#^M9nnB@c34 zjm?;ZnKB)ib4@4gq(ev(p$+HRjYyY`r^4yxJsO?ome<}Yx{N#WUS(zRNh}JFw}Ac6 zHJ0a;28R4;4&xkbW3{m~ydJh9|!iB@e+6!$t88!WUP>Fgc_ zx|>x6^)h|Lf`K4U7TU`o8ZdbpcL@8i?1jwl$USB!|svmNgLY)(`tzb0;KN?qY zCSa<)W#L7DzYe0F`8`TI&|%hkk^)ghNVGm(ye+%kJ_toXJ7Q}#IN9VejeUH3tF5@8 zKzNKQaS3p1NR7WZnc}!oMXyt*6IH1HRGc$znBy7R^4Kb#JF*6(ID$2Tsu%IMc{@OY zUj|0nH^=FJ{jSjCKi-~MrRt8uCfdtvGYX#|f0p06E z_9GDNbuQUz-j=tar{nD<+)M_3uwClXQF*PnO;iwe%L5mn1?eEY7drn=?HMuQ*6#64 zgr;^bF%OzY73nclO2pZ6gc&3P>5$$Q5Rq~~1GGVY!zL69md3LIV{NmjU!gadFJ;G` zV^#X(iR}5VM+0zD6@KG6VhnPC6C=fgLnbc-&1=FgQ(z@Lt|%wIQLlCUmRzVl)@ok4 zyyx!@Df508w$J4Rb)nPq>4Dk8$28cMg~xs)S2|`I~Fafd?*` zWx-?SIB4PPqYEsDn@neb7|006aGwJEIRwh8D$Wa16~n*VuV%x)-sP_E=tA+pS^VO0 z&sQ%5yIHKv&U+4&+-n;2uE@Udvyi8^0w?}i$G`*L+>}rTpE`$bFylOZ4LFe*vP1PBUgaedb3$`j zAYFhAK|-v%hH6X-W?H0k&P~zO*0W4l&^FcCj1g?M_G}y6pFn3-$BiDt6U2lYyigU zdB|)4&N`~LDQX4hjU(2FG(EahhQE}X`Jc?EBEt3PPaX(%`sdUR%8AfV=oA%gzP|C> zyyypLot^-Y07Q^T%0ehx$Gv4i0j}l87%4sX{{A(q-|q)QEz_c)*^ST$Q5$bb3Sxl- zcZ-+~ko27AO-Kfx1GbYq>tmRA;u%$j;jsUhPJ&O3I7Hk=!G^~HE_Fi?#ma6Ofc)w5v~-LHRlj~>viEQnJmOnvULzdrdu zm;Ks1Tb9>_xX4yIwN(;tYQZxqovuNFU#b{2y^hg8ORy)qDIt>{CDth(b4;b%5YKB& zzT}5xPSYM{Y|P3WhyEgr5~dOIIg0>EwhZ``4$J9qV%NLxqB-DuLL+vfADqt`DPHW z!>vil)7vS}4E@j)+W7Ps08SW18)-3cR1+4@e@vEBbJ zzwdZR1-_!-_@xA5wsW}#sbh^O=r8w2@V;S`%bEM&fxqAR5+Z8vo^<7BweH|Ttzjh# zmH~j(+xESd=$>Tu5Rq5LtYj8*#5PYdze4?3K!|RAGiuNrfKfKgm@!HYm)B37-g`wY zGrtu&(r@u5KaK0Zxn?OZ8rNseDlfVBorX%8uN5#FV!<>-iuD}rzpwf|d*Z>B^djSl z8#Br3jv*@r=H~PP8NG?!3FHn{dfLX>q4Dmeh0)#ZY$JO4rC5#?p;3Ddym7zr#9r1F zIb*+p=6I9LkZNdynXkJleBaxgDKJpKA_2%d>PEnSW0b(MaH6+6MYV~)3wb*lfMg!2 z(N=Y~F$Z|)Xy)JX5&pew=MnaZ09h#C#;-AAUXZ035P-hJfqLAIV87sH(1Xy z2>dqm5G>d{`$aLK*CPu89=NdWe` zA?G5t!g)C_>bz&wSwhGk)dBJ7|M!+(E|$Wtho&a<%pN7IQ-9! z!xbXgMiBwacHl3|HXc~E&s?A)ltZ1nc-`u^CzEV^IqXY)aEW~nvgb=9%YC%90smj_ zjBVe$CprR{cCH*)ztJ$GX8Y;P*1Q6&5*0kcsnIZ9|1?RBbCrkVvepo63^wns7s6JO z?qg5_fQ~A`TfmcZ4rRY82r&WyZ4bu}!}Pa!<2=@wGn>3|sMjDp_#xUO8nbBsbanlV zC}eo-&J;&F*1tS?ye&f%T9gs5I|#K5Kw$-a$$4XDoMR1*jR*xqE_-z8l!{~g)?+F- z>n>S&)qWkDYu8Go^%jR%vY|&(0Vx3+nn^V_2UtRmxZf$`>>IFy`}AYe4*0Ps_RRJs zfcn1br84k%>Iwkn30{Fh<-$S4WEZAYCQQy!e#q`1yUj@j*}`QG7}}EmGD_{zzUQsL zanqLAPeL;-=_M<4vk5TB)MmKbpQW6A(OTS!PJG{`%+s?+Q({eBMYu{zL4Ba{()j2{=>yHP35U^p_TVWnZcew^LG^Fbo9If9*utWmaA2{uI*e^rWn;Rj6 z(~G?$h%bZ84-pS>?=YlBAK=Mj(PciM@X%%a)?92U(SFAv33~hCxL9p)BO9{>d!M!} zHNEk;!DH_FL*LUrFi(yMt9uBvQd*>~a(#FCJ3erV6y7B$^JVVn1g&d>w}krUV$PxU z^yDAmdTIgX@wT>&1H_Jr_&XN@S;`C!Km*B^VjnF*e_#v$C!}AY<{d*>UD0C$=#%iv z@^b}Kh@J^1j^zFiZqAjuk_$$>4AZAwXFzn)H6M6Lp%{k5y zF2yfsM73&+&VvwzA)j2dYKNRH`G9M=H*n?b+Po-E7zYMLh$Q*49f(VcxZ2RMhZ!*M zoWnu90EJIee>$(L^d5`C-(3ax(}pkGHX3K2p~%7Jc7Fby z^UP3m=jJ&i|D301fF;$J2BVapOKAoTFCiGDIG=D^?VG^O zH~lP(E`fzQk17=vt5Oy$qrn~mQvqG=~$V9KXkeKKb>3*QH9p2tqS--u3m1Vg|@vfVt#cTh)(0kZ%F!Cv0u1EuaUFUS*K zmmRnKOX3iEhxnt<&MHQoHRavE(NSKdM{gx*khx2?j(>U|=@}N_otirFs*X{m42C{o zf-c!b8q=B_ay<-MSI{D%M%8!0GW|jL_I(lx7ee#Zx_@u=Za%5benK~A6Yo<2i6_^n zEB_inHFVmYwqSaHzH&x*uEPfEe})fZ`VZum2h=BJC{G<-=GRyqn)!z;&K80g>Nc0b zmhvEdU3q|bO8BKQ4NV(kot(-H?L6~JX9sQe^^rYp7;e0C@PKcN`_oLRTgMOB`zpfgK(er!3)qdOsH|6v`XK3c%X1Y;%={d4hfipJf|;LQDkjkf4>tkxbIn6{Dc1P?XqbCoaBsE z?N-dz=yRC0)>^yLdFN%0BYBl?j!u4^vycqYAbTtxx1Sq887EnKAvGM4R*MFHb>%3| zb);{wK0977&S)XHAqbRGUZZHETZA#Cpbh*ziUhFOb5EE$=r6;^OCP z@79R|*xn48Z~$P=j1n8Gf3bj3G7#bWhJnz~qGQ)=l%A%0aFORd1a%|3lS+Sfc3&o0 zFTnjP0k~hT+NK=t29RijecoFVPPB2#E-9SpW@F1 z8yA0MLcxxmIN16Xf^4#Bg7=aCs{@RzhFr@>ir|ZdNSMfF1o}G+ki%JEJa9OLiuVlg z5T|$0vZjX`%c>4fI)T%n2#j4R;>&jLxXMWVv|L`!_>6P2b2dCuo_pqs$C7G;xV<@O z!hTJtPW`g7&wmgTtrbN`CGaHMjry9*D!^)0$gKF9mXiE)OliS`tDw^L2cRrxa%sFs z9htu%*C%MFg$YM3$h%+~6a@t_vPKt5Vo1$@D`c(vcg$UA{-+8xIF7)TuSCC{#^oz$ zTyKGu?@)WfDXOK&lI&46CHs&YFie~d?CM3M#O6>KB^Kq{H%AvSQ7upHv9wDrxfbFZ zTlJh#GWes&&J+Cc^=t4^^x6H*Yj79o;Ux7w4jsyX=~YDCE(c(bs3@P}F{U=0f|GZr zDd3Mcbyl6$hl-!V{(~{4{!0Gu2D;0M4DFzr3ozKqzqPp}L||^j{=5KEuF{V*>O|(6 z;MLa_<#EeXaK^8dVPs{M0Odt8OaH<@f8fMYiF^bBk)1@IUy&lN9if(JzIW>v1Okk^ zpMpRDrYTacZNp+Po=3+s{6Q1wf!N}NK~%gPAJGbfjl5nDfyE0C{2HCug5J`j$h@hV zM<+*)ZDWby(lA7*b`(sH@ZUwUg=Ya9r+i8ygh8)dauwnhTvtoIwD*q`fy>_X$BPIf81dK` z1a^x32m2~=yzK9s@5r4i(BU!4J)U2U7-zpNxF4IN&c34~J#I`uEOqIhlM~B7_6vH# zK`n=W5JhXZI)kxvZNb-etuaU}Pe%j2f)Q{h>Kz6wSJJR0acK549@aAXA#6;Wg8dJj zxO+QSW(3p>#S1T8Zg0(%xSaS%TTWknT-@<9Kl@eYQFZJ>be;%^Fgg<$Dx{G?u|3I} z9!equs-i(#dbAIdER)4W#t;xBLTGOnsw{>*Z>KUGWu1jW_4qP3hK%Baoh zhm5jvHQ4>cs6%Deze^5MSEO`MGN@!H2fe%)SS6Ss#MSC5AQY*;U(h z^Mr|mu-s{LDxFOYCRa3?ac&?LeALZ~!41Y^5VAQZz%8LUnpXo=va>dsTF{ zTf?*vClPe0e%$YOda9n&y&Fh?=_VYgFDR2yqN`@D8dZO_rqJQX#8pF4xe!6YNxm-5 zgtNn=nq&(3tO*jD3&Aq~`i*mg`Z6sEIv}YS4D+1!(mOxBnjWbOJPCz5(9X469pq`o zeaE#+#=zjqwxo0T)}vEBzLFsv!Ikl{C0d+6xygo=o3pd<*d-Bsu=J!&oWmh72jNm<+CZq1d?nS;u~)bzRz2}ctT&UPq&NDpo{x=s zTf38w(9P$TKVQQ(!8~j+bk++;_{JemnW@^5_KyT?t;ZBvt9&1N@!Q`U?vm;IvDl(6 zT%bZHMpiZ1ov;6?zco#Jtx_}1WX1U_ud(w&E2}sjRQCCd<{i+M4l8$!BU+s}zMj<% zc0jVLFPcy!N5d=G_s+w2Tc)f&0=brx~`=J&X_nPQ39#ER^H z+9p>t>YH_A)1@>foRT_V*lQ{tvzAmF?3_2O{G}3xrv`blt?6w+{k9od*HyZ+$)x9# zuI}JZZ9@cT`LPfZwI~io>>z|}mm8qi@X-NHg7l#o^VYTbukKZWD9R$04S1P&91ZI) zbpSH8%v5(k+PRK*p(^NrnSvrm=Sya(P(DS7x?w}+$IwWMZB7V(%I^Gt7lh%$9IV%< zf3VqgoTK$LA!azhL_rd|GG<*=)xbxRD1s6n)_q_a)N00IIF$3W^OQ;cIjF6Ez;~$v z5n`k)iW&o2&xJVI1iy|aQ{TFD(_bN!5PD?_N-)dp0QY!Biwirh$OWV>KDc_aev=0! z^pMVTHiXeR?Wb+M?uy6*9dkCBz_E!2zm^1fDVX7o5}3oW-8^@Hfko5~lp3|1R{e_} zHJ=|r@MTkyvOF4&!|rC)+N1_fIY^C%whZj`n{Q012OaWni~Z8_D2+IwIHSOrUUK_b zMJHvOhI!e=T)!){URRjkm132YEKGW?OX+emi5F%1%Z;J)OBR1XuPz|Z>Pu%IBxWTo z-oPi>k>fG#+OG@cIXO95C4_kP63wqFN6gU*&I@rQ$Q~5XloC9x0(20K=bvd|fM*u> zn@r)!jjFkLcOThq4&#dHA~+h2#hV;uwobspEIjS>wCzY!h3hw)bVFGUW9$PyuetQekdI<2Mh^tDpYkLH zEh*~VryW6>5s@Bz7{T2q*5Ph|1;3uy0eAld-+%xO!Sp|Z=chM^HPU?`L(}bqJ)RMq z>dbfZhRDt#aVUW%s8v4_BWVrg8pE zP}3$J2Xj+rCumY@KKL+#(SX*Kt+TRF$$8H|0hXX!vw{hxw0}r2=#J9i_WlMJ4BNHU zm5^`1S$S!+#5u4`X&!N~`>cMJ!il2LwLnz|nApv&U~G0gZS$cBwDvt+PfCc@-sCdE zTZ_0N>l&c#`GR{Ar>G$?^QDxYxhRk%fP#p*B|5DTLHlEuTF#h@7IX{&YiqKBBaC4u z{(OAYF;S=w?zBn)bgK*8w%dOgLSX>}_6>7f+~5Poi6gx?iGo3obFi@Yz4;r-!NdKn znSdwp-v5b#u)7fEMhW*1B6>dziqqgyqX3G!{$LB;+R)MA1N7G0Z zTW~_Y*$Pg`KXr*9Yo&CwXh8j{&7bfN=)lINTa-;^Tssqrf@jg>4qOgjCL0jA0&gI% zIFunrXvcPd40u3VQ@yD2R}l(Akx-2u_v>RvNFJ-2mY$dd-C>K4lI+Jd&`k>KMH$q^ zLuZF%fKPPz;Y4FQ81bqP+z2e&^qC)gq(Py~kNo%jKq`AEAS9V_K-!!H0sQNIqex&5 zMb-}?MPL%G1KWLaK3EP%EZT*0j($@AZw*1N(9aD80sqY%#M+Hx=mf1T`{q@Zc_$R; zUpk#9Y^}ms0Or{Dxs5wcazwW!G$C)ueuPJzT^SUOc(h=e3PW-Kg&8m_@kixiOv)O} zSj2(tu`b7TOye*bx?Q$-FqX}?IUM$~4K|UF_P50*9M012) zYz#$PWhkoRa;`g`Ca23s0{_wO}e0Wg3$M_X) zkGmKo#7k(-Xi0p1Tj^J;@0!eg9dYg0_4@&Nc2nk9F>(0C(jnV@9!Q8Q+F6~N*9WIs z;_koX%rlgYDSdK>NBxggv!lm&zim8;OQX3gv)UkN`;BcM9s@7FJ07MU|=rL@c{(*?|W8(eQ2 zmqjyHKfTh{Oz4u81kbL_P6hJ7?l|>Q%iS8HkE7(0s1y2Z>U*jwi<$~@b}^(aJ|pk_ znjG|8(l1Ndl9ic@C2YLi7>1l!rX|7_(~M!|P{jlc%wNL-K37WMF+b0I>Ta2Mll4eN zn%p5Tae%#G+G`n%9IEV`j_vI2`hLzhVB z21WaGo!MV^g9DckGK~chV%6`qi~k;1qW z0tg$xOf3ggGYB!Bgmu<0T9mqbduS?M4qFoW3<3-Z56uYUEJe3^EKkl@ZfFAB%DG1IxfWD4GLWPZ8|;AO*Rx;yRJzI1puGY}wsF=_#)D3gHe*2xyXzhCZI z-0W44T2kB34EGaQA+NwHi#(eNP4E;^5_aIHgxRDU)s;%|bXjrZ!Uf%Hh*tx_S17Cl zD0x6m%wu&uym#wwhiV1D8jQ|JH!mc}q=8m0lh%7}W+7H^B`zP>J&`3jVD;NB+X;#9 zO&jehE7kkP=mI9>m;sEu3h0CT!J^&|CTV$QE24To!JXkQFtnHfwD1JTC+|@ZkBJ`` zHMfCeAyFA8X))Co9y#|;HIL4GdN3S!B*ojwQgH1&YicsnSmz^|xo2S|GahG2!F8d< zqK@3bMCPNI<~^#nBUnF|RB}~}HpdLBRtyG~05BrE&sTg3yN9p9rU3o7K9i5RFc}WS zI7QBJ`)1@wQFJl?CJh^bbB%{eks62N^oQ)llmFQEmA`JxmdtNvAYYi~47$~qQ|DB= z#QqcxM}uiv3{zsVZ2z3kwyhe-Q<8`|8vyIu5(vd5VcC6);}uxl1I#=aYmU$Wt5L~$ zPq;(S(82)x5#(D)3lHvs~Z}*}vEh)Vt+H>)ogPJQzs(IYkPgMSe1VQ%l4%VQ+=dZ^%CfDIaWOwF6 zEm06f@*jirR}e(VEN@9BI|Q~gli*-ycb~zvLV~CB_(Rs8(dA|tjAU|6#dR8ar_Krc^QzERkE%=s&Y@Qf5e9J%; zdR{r6D*S#uX6nGr)G??rdAJok8*X_9+BX+{+?Z@Eu%M$~o2`9emv`{B8CjaE7`u{J zV7|;hlHT}XTTF5s;m0{s*QKr$UR!+Ey+-2@Z~yo&fqt15p&l?K)ttTkTq)4vrwe1~ zy3{1L7zzzV$R~?fiI8Op;`=TqoNcJ*_;+kZtAFATiMn6F2M?O86@Kfk5zenajQ#rS^yoj;S>b1Pa_ zgYGGfd+GQi--~9T?HV|&HG0_^xBzWnQI5|`2K{`p8k|CeIV)jKi*)6~-3-&YfEig+ zKO)VSra!z@PvIz=j#N9A?EOX$|7vKQZr-99D&eKK1lZzD3^)K$^VvwN+5Ft=uW&P? z_^TJxf}9?R1=r5W2b&{A3-|=7=TIe2ZDjqfsUj7C>i>!9|gP zgIT8|yNX3j=MGF*egJ8LK!SrJ<=jqRmji0gzyC3(KnDssES<@|&6UlBc1p?T|6J3F zZ(xas>8L+A;c{JEh8gk5Gv_vXKO&O{IO9%jG+pKFSF`Ou!LvoIDq{WhO<7vbl1$lxXaVd~AxFY^)|VvwX?YhN0dzAU)X z`V-5M-Fq=U;}(SP1$ z={CLiSL9cLvH$X|hi{^1njy_S2(gVqvW8=$_lo*gUv8CC{@Mu$bO^8CG>n1-6TIoWeC(fODhkViaeo z4a#a;tfv3aZOgCrb&Z>G&u;D=RdvT|qLJnzKKYBj|8O}Z^z{LI3>t?vB4=b~0-L%@ z4ujQ4m@|K|{NFM5YBy|9TAIKw;?+qh*xHe`i$BTC&sS%ymRbyS6?WW<{|VI+byg^S z5nL+bkY4Z043eIqyunb%Ra>CMiE6`h#`eMR<4u7x&=0M7GAC!>sf+rXh|e0dn1w)2jwL1;-0BnDHT3&~;8RSPvzuoKAysQm*B}Paf}oc0YJjqkch#sg?e4jJ3O%Pj-f9`ddyb>{|~Yw1;ie%9wpPz4O2crgE?op|$AY z54iO3*?mfZCpH(PWA{=8pSq;+R6PH(Y0noM6Y)($HMm`#$6zUofMT@gpIGnBxd0Km zB$<#M3}$yjv-*ufodAD#CT1R^zx#qJQd|oq7x314JO?;b9L!e`4`C8=4Mg8J>Ev`{ zmy$s5DiEo%ZXss7#fX#D>(^lnoRFh`LLQ-xgEn=nE~VFhrw--uv-J`=zIVGs&Rb*& zg3qKeX*QstyoJDgF)uo@m}+VJFYx;4V+l#Ho%v6i4`l|rSe1L8YxwvNEsiu~Omz-s zi`hRSbrYJyRX#1lzJ6@wAce%qh@iDl`>~veo=vk1Veh+xDWB^jOK$oiOm|VHdg@QW z>z^AFHY%<|>^<91)Ny1rg7U;NKZrr+2*A>FN7?I&@mqc*HKiK&5pE6g~G1m*#^ym7_jtc{mKAwkV7_s6itn8X-w@_ zokfy!*?GUytF!?Tlu1^(W{m;pdh80{!DQz7)wAFk`}6#c=3WsFJ93o=N|{SM++3Wu zX3_7?p;@#@43DW5toJiVS4ypZ6gx#KDnjSR1{^u>>torsy6i?+>xGXfF*05QPkI%w z@?xX82Fx=`Fq%n7uF|nINLfW3d_`m6r5(i`!hwVVJjqK~TFsTiImg^jld)^O4^Z}H znUPDWk?|QJFd=v-Q<~l!0a%r4es<0H$7y9$$?)MfSIY5P<`G!9iK6^EH-4#>s z4t7$EfgUFU>Dxvr<}$)o?peH4!XkQ1@U?HtRIJdXfb6e|JYjd~eMBPYj~^7UEF3gb z2$78qFoq^#nBtWhG9=mHIrMJMhf)1ox9tUpRqC&6Z__!{DZ^lPz|YU{=6-c7e3U8} zp0iM0cmyY4aUm!m!fR%Q`9P2;TrJ>j`D@EmhU#aaeKpthVOCfE||G9}VoKayQ={hFA;bpQz1asu7*!s>j z**8_ZC4!M9qT@T2SR^%p4yCXh1-7xn#{5wdOAX|}TgaS<>Js=9 zB2rN&{d?pjF?dYzTi!(7Ez$~r%YO!bSi&dGa$Np|!OH=~ufLM(RN}so2ax8zDy>m_ z!e&iB+yn}lu@qm(jgQ<{eyCC$rQ7lX;DPoM$6^C~L$A=#G5MV?@H8-7p~Gr_?=3#J z@gULYz~@$wkGQ#a!u#TdW{4QJ3D9p)*(V#8UZa8JZ!3*YY*~Lr%hgk2dH=-#%EcD< z1!NUDJjiVM<-83`rK+I{u(5abc5b-$c=E^o%9J^6xp1H-OW|9bUq|wG;qikwa ze)Uy2kvRO*XkSa{BN0TvgiZOiQe%8<95+|22?`S|kK&Uz+KVy=IyAc*Ob~+}4dpO^ zhppThnF4+1AQMg-2C9#iyy^W0B=ANnZW6bPl(VJ_@AKPkm6}S`Ig^}za-?_!MMXC& zPjfF|(jpf|u#c&}Y%!kJOa2cyN==LX*>bZMVmf*U7{N#n07%ie-%}<9RgQSa!og4z z-Vb9uMGJ4$zdQ;ioEJbcOM)Zvj9=a4V9`YR%hi;tZ8{jZtK09L%HQ?-xBq#h2c ziVpGWF>b4;?y+dakbB3OIH-9q9DLUd5+%764JkY^D~t&XXXGtHl=Y6{Qvyp_W4sq# z_0u#<@YrdEK!aZ|h^Zm3NK_$%~2%mrci&9R^_8qGx?TmGV%KbG|B{57xD_sBAuEzxqMspRY89O<2>^tDnRj332Qp?8Dq0hMW4#2t;KnyU( zy7ybNop2$fS|J|o@3ih_1Hsng-}t#@SNUQcCq!mY!1(h7{anM|JTOmZ z0&5~^qKA;65(P9Dboxd1d#9LcJ@#~dns=s9^MsJOyP)zDp8M8F>G*sfi5CEDD2`ku zTmo+f@TW^w0hG-xYd-Pdh|!iJgv0JTOnow=jjHZZLzI}`qjr-WhZ!1h(YRwOc&JzO zD^}pbCNy>&n5U#ILoAulGF8Aqs>A3Lgtf6AtO67b2|U)><`C6KKv7P}&D?ax`6#dlWc{jp+ING)dld79hky^8&4rVOyG|5S}- zgX@39l6KCW?fq(gZFY=M;LRj0Z$G-mq6=by>d{cw^yOnO|J&^5f}K$lYH znYFBTCC8rn_@on#viLWc6GjE8XI|r|--OH2`u_X0JOZ;f<;d7;&z0J_pF9sU+92oW z_t@{>PkzlDXbZ6owGACusb(PU=HnzKFx`7n^vxxAH+zHBjm$l#VDdXlm{i#ExHE>S zl!YTfQ;)nu^o}pa#fEb9uN1A#t$iVB8glomvFpNhn7LlFg5mP}ab&bSP2Tn$18p0C zXGdVm&I73=&FrP|jrdrhUAqbS4qS4dCKBb|3^yeDB-F0gGGzV}l)-!qADwcSmzUjc zLcnK@Pb}`|SL3ll2;fnVkB^Va&c4bm7BWZmqcKUubv}3v91Qw^XXrOPY&g`OnQ#b& zMp}-@4NyxIK-QjXkTAvX!G>t54z1st;XbGn4 zhbuO3H*&s>4{!Z=Ktf!Zqv2PNKto^r9*}xz4L;&=^s})$f zhomeQSy_D%DHmMW@_C0&+22=FW1A?Pwskbx3OW>GgPWvH9pKt}1#=Vb*%6`ei%xt4 zv(N6_cQJA|f%B=mz@k3#Td0mtv9-0Gg!`}_`wbF1Z)CDutwUk~o=k|ef$%^Z@;7v?f|K{<0efMns~|5?A?+qOTIx`88z~et4!RWs-rbmZGJ&~5 zU~%w6CUaSfICVxZ&;xMOaG>Ej@)5^9_$Vc)>rCFm=jDAB{c2L@YWi-};#-nJGXmWt zI?i?WRRE!iL+m-R(RhB=H+AVcjpjf;4R3j=u3SrR5w5^+iJ;r0Ja%<4wbNg)ziUUsC(h}&a( zpSJjFdZAW)*}EVfR~T@0mA#u9Xid?IvdY_5LYyX-ieET)vGkUeoilqfYMPu;p zw3sqYZYyqE+23oSYJHToG2KRGgjoxiuY}xTCrYyx3XFofYtu0!#v{r7cq6RexOI`O z=_yX#)(=&FEEB8OujF=os!e>-R!m%uuJ^`}rs zsL&oK+S_iI2JC=GU2VO3WBnmO|Mig>{_Nyeg<8~#-7okz9+yU7l66vkOY>N>6a&T|uJB zdc9Y4c{!skc&)Y_!J`K+xhIZ)h6_z>zwzC2dZ$m|kgL-SzNS^+E6hkwPjA-a!l+x& z4cnBFoN)+@A~vxpuAT zOsfbTko;+>K3)eJE7^gjA&L=p9DOUF&(Bt@t6FC(N zTS=}E_DElR)yq`7slQl!qbA>^^_g68#?Z*x41!toPQYM`&aFhi+H zVD@sK@#NsIfxmH+Pm9^0FuKt$uxv!Z9l4F}F%)wKsDSy@4%78!Fc=6@)XiNqR141h zV9B7wfCDaWY^1u25y+Ax=oCY^P*B21{s75l7B0u`pxh3xeK*Fx8!ZS zx5s?Z?JYu}j6G}mK^zdfZ)Cq2haDfa+=>)%XIYYBrT9 zb*>UpK=a19`>(rRQpO}HDe730znuu}W=Qcf#9QH)yOhuT4(4u6$l|N%_B*yOZROfS zWo!I0F%K6B@pF|quE#HCeQuo;)r}g*y!=sfUyaYi^vv0irW~WMPD;{PtW0S%SK8>f zMG)xlWC`(MwcHDsNyE++>Z&$BOi=f)zzj>=w-kqgR$b+}RH!4JMoV~)Nz5WIbSGTw^-xuNqU6Ch8QD80V9x1#NK7%F+h+d(K0A!E@ujrEg?u$q6#KE)gV`j#& z`AVYKLQ2mqL9DqB*y`$rB)@Klap2QbX^HLMEeMUnVw_al=C1K0_Gv%azf{S&$uCe^ zHzkj8FiJ*f0LIG;3UmGWM3S>;6YGD&ob0LPe(kF*R{0(AkxoRP2BmX`ZN zF{AW;^XARIsmFh-6;THKU8)7JrmMpU+Si}?Bf3tX2uDDyeKh0!3pOUx;@$h-3;E_1 zfBE1c5J*0*Y*HYL6!Yt`WH4O7q`Lv%Sk7Xm4JC)5b&jODz%g9Um%7&L#@X|$)u|Ua zggVe7BL&i%eS0X_)mUPK&E8zR&fUmUQ%QoBGg_n!Pi>$>4D>aHD{DXrLsAIjyPl?K z12U|Qi}*t6scB#Gm>xCdu*dGVqUi4+$`Mj}W{GV6f@MjS%aL>s`Y~TH0+B|$re1%z zZe&?QngsB6V85ho8)RXgK{M<6^IUKD;RCN%O@9N&;HDv{>`BDa$TJ;mm0hA5#ih6q!qpAIkJc&k2! zB3~J;JnH))$84W(eSAV?^zm1x#$LcJJpwgzu^mC-Eheq{D=xp*p;g(s(Qe7dL!}zf*G$76r^zuY#YC+U$2N(kZ1i zXkV^mOirhSfkanh6dNKo_-g^;bhHS3$8K2G;VrMcUI4?_yu9ra_d@NdX%ta!d7a~- zU1&{gUQ;kQ=D`(yZ69V$la-ZC%i~AF&T131V>si3)%zdSy>~p- z|Nl2$qLh@Vq*96$8D(!OLPkZISy>r}>=06E8IcgjN=CBRu}4Lbl|7D8;T&5zA+z7( zNzvzfxvt;1e!uH>{oa4{KAqEfo!59iAM<{{KVTbj=|)zpZvD`)i$w+LJ6ByH7Pd|Qhc^a=5;wS4mS#See{iV7 z69E{YK-#nO620T2Zu?6(C>pkRmlQdf{yrvp31nI$_RAI)v&E6_@|vJ#nbJvY?7 zXjL#+G=AFd-D{8Q)sIBV|9B`GDW+z!DnF_~CF^@5elxU;ef&W^;!m;4BZLyirxmRs zHb%*me{zmWPS5PF%}Mq9e@q6kb%L?L#(Wd!RyLhU(Jj{c+2H`=YWRi_Y7aKMgGo9j z7(@?RlL(-IS5q=*`h{QZ0jpm~LeJw4&%dkuzG7MRJ{6LYxrL?AMOzKQ-%M6^Rt6OyXa< zen(18AYPAp1V|-f+^S?{?cRY)Hp+*!z%jw}W@YTW)DFtnu~3Il+_GwNsXe5y5wlBx z#zZjcS_|sBZoml;94=o57<1XCn0qNl&ON9^^c^tlNC8!|-nyDI<`DbG?sWH_%Gg-4Zn$q-+yfFYWz3RQIy&yOs zWBOLc8kmsu5o`%(!~wS)F>zZKG9$6sA5T3veyAn46oL3a5Nt(tiG}ay=x&I@khlQ^ zWeUGNS=Eq`>Q;=w3!pq7jh8y@DMG#5{LZD3H@qX3f=>MHY^vYV_F^k%Bjnv9fkPu^ z%Flo?44R$osdoAnr>2b+Q0s_zDC=6Ur!gp=M##3ws+}c zF~Rn}`M$vvwzBP40XYcvmbq<5t9WvsoVFmohxBUU@~cmZIrQAqChaUqwmJCs?6M_H+hVUDluxhI zuE&$(v!$^MI75B>f`FkgwrJP{8>Jt?&RRM?1yC)gRDk+yYRI1#0^h*;rjaBN# zOm+rGJXvtrf$OKRoEM*^bs~Px8tu-ZY)jg6l0E!)%>Y*R+Wn+~Yyls##C@_>U4uGS zA9qYR^Cyfgo)Ivl#TIEfNs{I@p%0Dm;O@kA$KQtYB{;dK+=CIv|)%WWms{aP?}OQi}yitLOPu*F(C$SfX=VhD}IqIS(e)u3ULxH#5{HPICjg z#2)WDr_*uT+zqF4yX&8x7lmB)3^1-5_>?xY7x-}^akrD?)O$@w!?tT!fkei0wg)VQ(lYLArR1I;-1~tzeLrW+ z0F$dr1kG$*ZyL2zY5#j|(0REv9~o?_##9rPla3Ql9&5d*29rBFqBtBB-WR8CnP}wp zud?uqlWXdYdX9|2>h}8kO6n{qPSRdKx@(=!HjoX`Giijl8(vs;y8rbLrXwJkyY+Mc z;3gJ%Q<2rcUeur&jYGt>z>or zt#IH+kEQ=!!GBy;tRCnI&-t=#T=$&r11IU&fB#njKezbdf3w^4w&4B?x%>ZdjeqW_ z)Eot6Yh3RO+O=%^pI35na3jttvaD2p{^Rz7^Ie38-e|43b=`Z&f~Z+XSZR-}w_yC; z;h_%+#*?r6#^PlSUP&*NCX)Swa5gTHwXvFigCKzq zV+uv|V6JD1QnX#`lZ3T3bUit>UATK$t=WgInjFDK%@Asf3(e|G1ng`r%tyC9m0qZ| zRu{?NCB;-t@g^xLtLVX%ou18G2Zd}G=gdtxh8DxYw&62gjNSSG2F)=E7d4CNUZ}C= z-|WzZArS|heO4w{6KHUd#+L?E4)-(EoBqTjb9M35*Wj=g%mAq`(Icj*m}ffI}% zlVpD|(X&7RQ}mQs;Q}pJs_t35O|(*n76UdTyjDoL8kbV2G>bE)r?;Rgvi9;?f{P%L zibEikw%AaXk}<}zbw*6L1iv{1^UxORZrL5RhIs37=c3!W)IP3kydqD13zd2XMSlJ0 zwCLU`S;P_7k`sD3u<^;wzbW$C!bGp$alxV=Ro^hSoU{GsXEc@jVbNurhu$sZNFwh_fqT~I}{{lu0~DKS6Fj0o3Iagv9s zazAJq1VS+UL3{*?B*suwMiiHF@w(MqV}obcUj2ii_(g@0mD7{+lJdS1&Z0>C1++L$ zSoO4zrOmNAu*6aQ87LH?X$WTt!}4|lu>1yQzuU{V1z4*BdC<⁡vH}osZy}P08mh z6XMC84{2SoB>wChOZm0Q+Ew^{BtxdcqZ~L<51699-#>lCIVCbEVe=#zEE7m-qCS)FR|>(yNp`qEcME^ ztIx;t+8(&ILYm|D5u*GK~WHX~H*&cXCC>pGXuXSRe+0IH7O z&At1g`y?F+8caV3#p8CReAEY4{iM__WPx%g%!F6oY9Omg_HLDXI|8+sdh##OqG1Jw zzVf!Zypi3;#>O`FS1(`wIyyl&vrqJI0t`(SD2($6xydN4yd{@_&gFF>6c}bfT!G zn8Nj$$+f~N(+M(nKW2Ka?N4Z@R;pV}VsH%#XY|z7)R2mi5tv-0Lx9B0WatC*S2I*- zah!&**>?!=h3KR_Z*HSW&KjLTs|yg^gEdy#M^UC$1hc|S2GpyYo-wX&W~PsD+X%O| zG=Z2R+WKPfRT5=F2TJ9kitY~`k&6J9wIp-cpKka+_(ceLNbSFY;Ss2+EG)B^LxTyD zpoFc>L!9Zpn1Zv(RsnLc8AdQ$_U|uO8o((Fcjfa@ublI4N;txuJhypA;j!!{BWMw* zT)V~r;)xf|o-EzFXHRG&G}6SLU)Z~L+oPevzTIV~ipR=Zjoe_zz4-koM!p)h+3P_El>*jlE!ro3` zW>$shxtle@et4ydyT)!c*GTFMJYvK1@;OZt|WJY-|B<-k5{B$Wx*S&U`>PDcfoc zQcX*P{6jX?!9CTvRoEv5`0Q@hwANy0y9ukXtWc#FrSTt9Z{iPY;%3LpFIth)RN8t% zL4-w8Qjh)Q$*iLID?2y3Ze)~A)zjm2c7{~*KY|F0`q|Ll*4p|7HN#6EAxYJZoT-OY zle0yjbS7+G8`q1;wV`s+l}rV#>5DN^As`}8JW#7vQL{`2P@*W zqa=E^+)PZnYErka6=xdTje8Z#r$feI<{CD9sK4pkq5d8!lWuGQf$tJ=$%@*9Dv_V; zF;!lhNMB`nEv5K8n9V_vH__Pua3c+1_nA7c2 z^&o%D-QCD$MpZK0SkR#sDqo$+n+u`9dy}kT_|n!>eu@{J1gGp+#pC8YS|JY|c;DqH znI_h?*u2xU!&k>H@z6QiWU62MYAPy6N7~dyQwIC;kELii*RfYr^pMQ!-kmL6z~Ei# zdoyLqT}^7CrShk9GU>w2FKSB|=xBK_P^vTpEmlHqVj)IXa^z5d=3;UoCTa^`>5g}k zBvM>vCQmeHbSEVXm;V!+*xILTcpH5fdzQSmA$_KavlUA;X3EU{n`4P#3lP?yt)_C1N>-Fx-}eTj;;&b*JvN zcByXtmGN0%&BECyL=MANI4IW6<)Ilmv(P;~*9xa^Hh(xjMY}kp)U9x)D6FYoKWdD! zUZAvgu$48fU>=sX(uj3d$8+mI^%2VvE7g)HV+%ZHc$`?jHKnlDq+vXB<0i`stlVUt zv~m?fjjT5u%6Bf-U&+mPkq(%Gx-Pf{WEv%BtuYXm*O7lk z-cu8lv9CeV6!A*Ykf!A*9pS4{j7UL2fmExUp$moBJd=oP4?w1?+T#FO;w}Z%WQ*Qh zaOY6OVm<8kGhj9hG^l(IOO{Lvaj&fkRIKNDv3NdZWyw&6N!>7M&1bZmYEf4}bSX45 z6oo}dy(w7R>h!}2XdBVpDjZ2IM1Y)B@e&|5@T#W1d2@kJh-d}RH~@`*FIogyeJ-Sd zB^|-ckC>qNYWoDDN5B{G+J9q=G{AdbyF<)4GJkBU4 zguop9WwequH9c0YZU?fCSLFr;$(uZqFLAJ_@nh++>W!5CCM5sLYSzz=uedNwbIG?% zz=O(rYXuZTTk02ZRycJpC6HfG{R-9s6NP^cDo9lmU@s1*<$|$L92ywt@kr}bi(2S7 z4t@rqzcf^3OJ<7WRruT&qVMgO>}hUTTEt#x_-=#!krkO^M|@BpQ5mp zq_0xc5JSZ8&6=26yz4olB9`tg9K}Br^3cD zTkYG;quX$m;x|60zNmfYmB?^TYbmWG=9q+Wi)l**jNN8O?f1lh&0l{8x3`ahv|aFR zq97iyQOuw=mLZ&6pV_2jI*7^HTJ$38`_5I*CM5tdWeuU{rPYed>i4@-u{%)z=xrGF z;;a1w`m~{ob0JOd-o0}%3lo0Xy`wHF9k|)QQRY_rgV|5$9S7@6Dz)vyvi!Y5#14$;Uph zG5VgiNo%E~B~?IimOfd5fAG;RW@-Ac7d*yKe1IK?IMD@-qku{T;fNz>2xMG=T-{3B z1bKA-P0*UNi>5RCQ zgN`U@-#Rr>u>Sn|_Mc?-zxg^iB8>o+o+wZyR?*Nnkybn%#-Wv6S(6IWf3rqG@e_ev zssN(M4Zz9^-+WkPR#PO(d*Dz`{p7){_?R+~awMr2Kgvz+|Fqc?Vh-ouZYlEl#NnSX zFDxnC}$GHgf%LP3a*C{wOMEy?9Nv$>>@3u^T4NGsT#t zy?Z@grB@|I1CGogtC-316U5WooUoIa`L1|NOresyq=)jaTg5@t`}8kxbI&Zfw&3Ra zWtmh_lYu~VO&gmjAwmoyHpo=CusH8pZ_Vr8sv_z%t&1a!etzrf4ndBJcFrXH)sjoJ zII3*gDX?}zCM8E^=~2J4b#7o2Zb;2*i%?byHfq*sH2q#~N}V?L*Zxg&*k{~^8` z>Nl@{eb?-Bg!6ZsWdm*IihKF5rCBLP@Usp=w^lp(|7I+AJiwhVWF;7_uBd;#@i63^ zN$=ktSogyPIwlQbl=!hVKljgpFUj6%(I2@_?hDx-y-b z%DFJWaA3zAOS2>PvI;vk;GNyE=JDOZAxa6`?C^`DTUNqpB^|r{ukAc?AE-W3dnjL4 z+{z#uk88uu2?usIW)dwK%$$Okc1kc~>LXL@!opa-jn5qGbDGs`(iBz+_>h`^<>`@7 z%@AgKC6TbpYHB9R+Pb=teodxR!tYl)NFPuKX(#h?2gx!u|K_x4N?noE3rUNWW>Hk? zsa$cAq>f|Aqj@VTUE}jifo%utldBb@RQ`L(Bu%o#vq$9XWgnDA=OKPH31_4DcR~EM zRi3QSVBnXRCR=x9K;leB7iy!d&GxrA9`L}pXutLl3x@jBGrET>?J7rTS4ABbUA=P% zW|s6eRWpk~+_yGqRF_PW%uM~FB>&D`62Y*N%vv0nRf05O@>(BP%c(p|ENS+Ea~ zQgjk>8T&h0dbn~?x6?eu!h9jA5YyqvPoI%NMUr}3uOCtCaUGLJFG-5Jbv-oOa>$hW z!~xG5Xa-Ee#PAP3KP3E;-%RTE`Z1!~{aEKLWI#Ya!hC!w&{w%|yRZP^SMO%Se)D6q zbQkRO@yh-PtpaUa_a4(W$I=liyV_ad9nBJ1`8|uHcKFOPnu>QAuC*a_e}U2 z4*TyCSy?ta!Mun?2duw{?7_S!JY*-Xr3h;oid5OX4R1wu-Y{A(eUV~fgyS~8{3N^< zDE#xg-(d=9-KPDT|Gc$!V7Z6jaqM5)u2s7KUEpGr{kPMtKo|+(8S=kAIAHYa`&U?5 zTdJz7dG&xHX7S7tCZ$bcdl$~bbZc%eFE3Pv+q7i|5Fg^@SZK|7B>$=F{MT^1K*~E> zFgkIrK8#-h@McX{GV?y7(d&TdPJSK|g6%_5K41=g?6@ta5Nn|bu0ya=Jw^wd`Vm8K@sb~53^$)wK5c7&SKnh&WD?48+- z>C6s5*u;ZaaaYGKzJFFMWfWZZ_bFk{KZ$7&EUs@ihS}4`FJwTZrXRk$&~XZ4u%chz z-|sx2_d34x8Uw1|y@Z=jK|}@jmM>A{ z5@aMZSN81LQ#Y39_U(}2@xM3bj+6<-T>(Fi6Ho)~IC;NuW$E6~fyU zflh$L)Ev<{7Q9jtUO9~q0oOMDG2r?qp*b#e*Jb{0;m}chhqro%e8U z^)UGxct-m&OU|;qm)iX|gjL!G<6L)KHZU+4>w%24wZzZu>mhID!`KHiYwqFeu1AOH z3k)l6r8qs6Nq8FK=l16_8VQIoZ03qyzm@~Im4)B!*2--DnRbv1U?g&kt~Ht|5%?^9 zN@(oQ3Y$WA`o(Zcn%mr{OZzz7t-jcH7>h^IsB-jkUbKP zYcU1_BpRj9KmCJ6^rwbn?H}}ddG28tb6UxqdDrg!k3dQ&QrNM}=3|Sz>a}Z@J--1? zpP1F__Nt!8Y>>g;YX9Fb7M3HqDje_Bysy?uK;zH3e7whqv2YM&<~8BMS-Njf3sJ@*bJzLrzXQAt~>2Uf+$}HNprX1>;DPUzo1gQ`{Bkv4s&5QrA+!b zC?(O(SRf2n4}cT_xOGcONog$O+zw%B$iGc{E`|$TMv5sU7tjRe0UUikG$?B4;cHqT zPj>}-Q4is#;il0eHqM|v1{xrQJPEm9Xha4YVFakkry|b(M1KBmWaIl(N_qWiN|s)p zUwL@_F=eR_lLoF=$8-o$ICLcJXag>KFqAO>x2fkKJJ%+P2)kki7n)nHy88(65A zo!g<|mva#0CQKf50d1>XNAW0b zdolks3>@C&W+GqcIHq_r)#LZlBs^SAGy$-jmAKqyU$HSepaxp=vM%;d8Q(~Qsa|uG z>DgBvHZCtRs1Hwem>wlYYLCw*_qvY=C)rV@L!LnT-TLkO?Pt$u>op@f{QlNK>!g+H zwUg8k7X>$qr4XvIR&~M_Z9D9cP@J$Mo2?gTd2x^5CnHrS;sndfZcC4I_;!BlL(0>1 z7dZlhIla81-*qg$w8-&;y(b+`#5xQ zp6+Q^q*&{mIAfL&6l$pu;_^FUzmC*Z9! zpBXlSZ7j3d=2c_0_79}KLm)^|y3=gqdLR0MgRc(oi)C1Q=xeXu((CiL-^39Y#T3JTk0u8ZIp1S?NMo#}ma9Lc5TBxw$N_*idY9A0(_muAL>MKka zbR+wl6vU_P4QKEirMa0l*=VVx>#bNey|&!sxg01b3loLKu~4Y142G)&1G5y~}-pxgp3) zcz-t;J3Y>uz9)7#MbgV>p@O8`(e0U&I4bUuit7k zOkrvt46ZIq%YZ%GQIwAiB=j-(`6II~kO;7UmGKLlTDQ1+aib=|Bo)$RQCoL?W=^Ys zLvGcGj-``wX?3qoU`@#rX&&uEabvWUS+18PEl#RkkprN>6*VAc>Aw#jGl)QHK!mG0~ zGLFXddpidpqlJP&k(POqT2FOrxWJ|wNN+#kwthCR64)eN7&ec^v!0l+t{qdzMU@F` zQRifY=oo%tYP&#`9VhSWs2qd$pUS6+KI~Y|g$)~6dF7?gTv4~1JmPPLnteJaZHMBZ z*{65=YjNvXCrU`DwvA5Wea-=4O>~FDQv4U^tg0lhPQL-m4;k)$2C5B?^$`T4?ZXpS zu>yt8_qrUh8PQ6*Ii#lA`piZ1`PY|?$B$v3O$6KBQ&HKXyl&@g1DPMh_*|RTO(HX; zQUAi8yyc(8v|9O*x}(q7wq}L?eWi)?I~7Fknyr zW&V#32Woxtqzef^1LyNBfSuA%)&-HfG3dBVg0_S_taXTTn!+*vU^Pm88-$NH!~A40O!@<`-xwEB>#f1=GJM=&$KVK7={2?VWBvbdk0? zYd0wH@1Hm|zxEt&@7CWw#4I7R64=pP_q;?tQZzjgE^MdRYViI+g}4kTEcd@K^gI0# zB#(8lgygI`p0p^?xN0!%;m;Nasw%-BTk_I=bgdgj>9@hK^2^UX{C^_Z1E?VK`DDST zq(YTZz)`eIu?N8DpfE$2Yhr*21EjoM2$xQ5K|v>n#AR2e;R*=bjA2}eiAgsg=yy5X z4_Wh!XZGlCXK0Rm)wA-J(Nn1tk(xv=ZPRiLN-A2)w9jE(ErF7hBadJ^eIK6+Lb+bs zY--jq11p+j1p}}_^D4Lf{Zvm#P*4j#L6n`HJvKJ>Klsu#mB~^#-y=CwzP`-f-VCxP z)7xJ1t*pnez4|6Po6cKBEAJc`OQps@$Vc!`r9DubpNp%5C0YX2^}%z5Dd^UiloSpP zAJgBm>iwm@kncOUIko!rsg<`hk3&wfxj@?5;ASp(uo>fby<1>R5^$7DZ-X4UiA)C! zST7Qo0KPKwN%hp{K*7GC>U}9M^iy?pcZfsDh~5Pl?p??A53u7Gi$BIt&iIDqK?VoA zM;EW}UP&f5kYyl9)1KMV8OX|WWI?e|YhGm9%auc<7)t*RoEBB=YD)kn3 zbgkXSnoQd^IXMipJ`ix)qv3S5eb64HpLlluY{4JtGHnKem%9DwZPes``}x^Lgwxgv z9%Ni)`*^-E%1y1jRY?XhoUi>mkc+ErP}g(gw|WEY`Ov=!P=!)oU1ZZ$0|z^6Pey=T z<$0@?*H*3-K|H0lflV^Xy==B}OEOXUk6Gud1h-K~rZLDi)VQ!6zV&5eN4mj1O-@G` zt}x{{;>TMjbPg>+B%y=0KSKWn)=__}s4f5J&}Hb*mTi}9{v7%&6W`kcNYO5IrGa8$ z+c)*AS6jy0`sNgcokaWn&l$mpWx^OxP7v(xRGQOx7aJ600)ZXZT2u;3$Z=rLtZ*W| zh0v3M(F+3e%hKyI|swF!Xf>P{zi`J_f@h;GRd>fOlI9bZS<^? zC5_p?hCd|H#>iWQzh7yP!U~3qA?c9V^z=)J;|~RMD&MO9s=xsVFN5fmh0Z`;whR~t zUa~ffo@RmhWT(^k{X>$pp{zdeCMoqzn^^H)TN^aGwd+=&uMJkfx`sI#66ov= zy6A3Uw7*pBHVO#fc#5LiJFpTKr&R7cA7sFtj@1Q?f~NU=#c9*6hLu03W%t!Zj`An4 znjMxs-%j(0_9ofx03_}RxF9+h?xtLxzZu5cyxke2(m!4eqZleMU1&ID-%Xga{6RdD zN)!m0y^q{8#%P7WyembCmi{`DK_f>XR)N^iQASZm9}vc6N2Ki5;5Sj#;Df@JZb3X` zF+(Er)Ig5|^lGpc=I`!r%}DqP)hrE47LsuG3D%QY4;&6w?aN#ms08H<`ApYmvzqVr z@o4Haex#ElgjF{SJ;Uq2r(so&@dZY95XeY|;_p4H45pKJ5V%-`0oQ6W18~o8j3#=1 zTVaZ3c}>hH$V|W(NuG~dAsqx35=lBC(gSeZF}a6#>^l{WinpYO|i%U|8KfiwZLJpJ#;@QJ?dDf1(&K?*Jdo6*{7|^ zcX}QqU?oIPhU9{_>iMPV+|sIqD0g zNl*(6`(301I-A0nEURvj*{?f2G!lHpU4Ht*hPebo->#W@q?XuN8Jyl&4fx6`$UXd% zYHL0@a_W_IHCsS9kfvS=<4+#xdLtJ9)g(1}K$8u}8EhfbCD(PEy4sL7hynD-vcNMe z;20+ok`I!TWAa%(SuW}i0bAXqRU#jqHu@wlY;c!s=#VUM{q8cJ$^1?^$fTftr)AFx z{!0nGZb-=4CNd!?ixOaBsVHrE{|8${b^oC#qBz-08r6xU&nvWfA{Ma@do zyJ0y=gR@~Vtn5|?qJ8(?(AyRJcmg}0$B~8Zm!(bab2-O!UTTwC-MQmRUqGfcl84uM zhj2d)ijNFc2zA;zd#0cc9aKEsS-tSRA&BPp&jd-N6v8~GepY9f;C5{4ToQIOdq<{% zlw@GRz`%g9dc|ucCJOtC39NNVNAd<;=&;Sjfb(vS^fp^bKw!nA*P6=fbTv8&YY1ECv4^*RRP zy~XQX$yJU-=Uyw!EqnNtsLNg>1ldnHJ>|R*kD{SmG1)%Z55Er@T~8AgFxnov74zwa zB2x=mb-reH+c(OVX+G11VbK1H={5wg20n+;&+Fren zhwjt!fY(y&9|M58dG)x?YYO`~`tEl)o?#HbFz$>P+4u~y0%w>zAVVMg`Ko>g4|jB~ z!N@cWz-c*o%~S{jObkDf;S*|iQxe|$U7;$1Dm>wN{*oPci&Z#i2<+cW*(b!qVJJ7| z@F9R&Z$4BAjL(1hO=T($ZCs7mD=u6BKR>@M8cGeqdIn?^uA>>Ume0yc+x7^YJb7{? zrw@s2LHgG9n^%?#BpR$5b7Xqd|Sgj;ImInPL7&Mz8~#Oh9u6JHW;8As|vxgsrM_AI~wL zaF(r}ioVR+ESHZzdm!}%Fq^v;8trw2FmfgGk_64Nvi@Pv80D3>KhdGsvAUD~99 z##Py7j#<6x+Crvh-yk#2RgxgZ6dH9L6_a`6H>Ivaq#s-YebIzyweX|r6wmT|Qj4dx zNH;bWxBe6kD&W-0mgCb%z;otyU0;$i>_KTJ3x|IV@zmGA%}fRw9|L@jIO&9gLH>P2 ziGVy>^ZdgfK2_K37!JsaB)v*Ma_Quk$AY`u zb^H>&jg^zHkmN6B7&Gzc_2aBEX)gX;Wo<%T!~q*J=3PYRox6w|Q+5&a^>-aVaX@(K zS06LvWxhgLR3owVIf8YYJZq=M7aod#6Fb1!7DlviEm{zBCe;cu!?4b$LP>V#0T|9) z_o+Mp)Ox&2e&ig#clj1i8#b%a0^w|TEk?8gl2r~co=5AFJ1g^S`i|CN+TPwcl_SkW zMe%I*pozhhxH?-hv=g5mnYwFmaA?-B>Bf0(Nripq1TK7f%S!$%pTNWKSJNc&@U-m( z>{{NumebHw3T%wj48e$>b^GVd)dc4(f>&qjTztRYcQchBsKiZCPOTR{QzFvXXGd}B zwR+L&QPGcrkVGqU^SbcrY>1xE;-~o5g;(QI8D9J8rNir`Z=V6oM}=u?y1CwmmH^iJ z|MWtca>)<53?2cP1l?lf6Aao#x1lg7Uophv=m)y&FOUkE_*b5+;uQ63HeR5jaR^P^ zIo8EWyBf1U23?1)V--R?4JRMiQ?By0J*B8_Y)~w<#lTE}XJ%~>=i0@aA&&+^ZRqeoAqkA5N?6*c%Ja9z<$nj7*yR2^&+pq9J&0yw z9e_usqWozq5OL|rr(r&w@RvKAVgUf@X;z0+%wUuAK(n~4S-9be@{3bvpjdeX!d`hk zs-e|orw8<^xpzJKFMXH~#D109i&!Jy>Ry^tl3)G;Lr{WHSKHJF2v7GA z4>%&2t^m;9(w^@1@Ia+05H7o*sTsNeqYLRHo*cHGN3IekSql=|#AF#Ao}*#BKb{HX zr2(VFs$huK?%K$3Ps}d2LhB-8-VkYo&-ajazHMB7i#@V+jv!Blcy*x7P6$Wf?rUJX zVm9&f3yr3~myX&&z+r9Yji2YuG^taKFg$5k?Fj(U^f)QY9ht8Q-Ou&3=#hT3|2^0or2lLl~DR0 zfE5`GY*Pi88YpwmX}^8r%wwzl+Z(06!UtS)Ap+EUk7b!jUAQN^|50;UM2QEiGh>Jk zCz0Zc{P74QGHWR9hJ5;Tp#iU&Hh%nG1jgB@grODw~Qjvi9>|LL_0GU@jeYNV)u7&HutIKw$=nPgV)>Y>}ws zJK{iAx19p9Wuwn*BqsPO8=8O5bNvExF?>Ve+NYR7f5yQHJAo#$2s^^69P29yOI6`^&_~SqIF&Q< zl|0aVX2cr>ln3q9IDL^weLSPpU65T5HB|Qh+-+P^=#aNCSINuD#KW6riFy_7-R*e%)Z{zAkU+2olPl&czl}d_k8JZ3zOQRnu2!xDF>5? z9h9>Z)hkpq!#NU>jYF0%WKn~M!B;UI7 zCQZ)1cvthQYDJr~hL^gDIe zYrFU|db_>o?LS-_orV<|#^A2?bmOMY&BDnili^}Z>Lg~Aik2EeospcHvwL`5u)Sch z?-Ym;2O+v9lHTnm6mo5UiKQ#&>wWe6ak1LZ7j)l~`-t)CC%e0F1@o)=UT=Z>Ed17< z38J`ww@WOl^H;WwmK^2`G99=E>W1taO<5x}y|N+G{Z`}t1n)>q7#x1nfbCIB^yDng zu?11rOfW)&??_YgO-QW?*~nzRArox!#1gHXQkk^bWHnJT68V&=;F^ zLZ+g?_hrz^MIjki3Ib+?;EPTkKb@$JD#BzAKXp$bHY#n>$-qMo!`iWe>V|o=GM=tp z`QCM>E}Mt>UYdF!qMA~d_-bBu=O89)zftp-aETG+Tqt;d>N=X7yM1IV8#K`Rhs&;6 z8ytcN_n^nZ4}$ZpA7uBf4!PAuTDd6yC6HHA16=MXS_6a+388ASsL{^OFHBE%WRFIv zd$IFisRT`%PTotlCvV=cXMMkGHxhRcR5`zO&9U$IX<)eFqPR(Tu1&_NiP$j|eiSPj&g;Ql}2t{o2n9hzrEVmk&alh|~I6-=!_zck92zx3>2c{fUTB3)HWRU%5Hb5Lts;QTpcLTo4xx zJsT*Zyf0VAs4>VMKhBXFAC^Y7<3_adCSI45CO_l7`EI0hx!6_zS|HWGa%wAS#dF`w z&L6XhtUY>BeCoMFi>Sk}QQgZkW$hgfR+?4qV61ga2vMZ%455P@F}sR~E@1X2d*=>~ z7J!zP{LhVVp$t`cLvrLv-hLn7Zu<;>sO^O5>pDB$l4p)bzgsGOV9 z@bQnV?*_8ByKjDE0}09r$u=Jw7T~0&?N`~XbjP%nBIV<2ZZgdimw?LvSg~l$I8}8? zC8MUR*P!zf;!{UA5v<5qHWYvIg^2eh>EnWFkVR-)>luo6;Whdqk}_^5ia3I7VRLHm z0r9@Nk=gD&Iw3IEV`4Y*avk*bS53QRBo*pw!!3X3%_>XG8fAS;**uHt0a@mw5G4;& z&~*Qt)3FYa7=Hb#Tr>6@n!JELHy2*qUW>u@;8xin`9}bKHS2D4Fdy$@(+_y@fnW$) zm-*6ntjX$zBflm0u+*`~69ih@pA2&DXviA357^v&{hsBVWdV!^tUc(xlmHg`Tk>WL zlfI#M+Z(W(9lue;#59$cv{G=}lPogGiaDo|A|&G2f#hhPOZ~m3f}so~Z>T4bKg2t4 zZ>RQ|9}GYPg^)e2oB65O(K_>thKDuD=Vt^u1~*YpJ>{BvnKiJ#)85nHo2x($8jnjIW~zmc!|v`(=!l=qfv)}CdE{F++5PnO3|as0^< z2O5ezYJo?^=OcD~Kf^SwRyoMN0KV>&Cqr@2WxbKCPd_`1kqUsO*A;6gH>8bg{t^ik zDykK8s?dkz!Gu1N=(Zvci@o%`ZB(FZ#i&D9Pr)8t?>ZeqmjRJGnx=1UH?-n68CchF#sVbz*J(v1a< zW9DQ!=BGWDZj%&q!8r?P;s~5vTmtppg;TBCBe89z8fcW0{pBrgOX2QqR}CG4V{kh6 z^g1OQ@W@*?bl5;;tsr#q^;BcFi+57pbMH1ut0gL~yt0~SO%x|e5M)up?rDCn!;x)c z=csvxv&OUiP{j(e^oE~an-^jPsYe%qwDr8tPc-!{F<&}?$p0t!HF^rMZ}SWVvd4*~ zZ-flhwo0RZljnKz@WUDb@g?D!b~&au?q0jH7iaS-XgujJs;> zVD!xS@C|z9TRB3-=ENInHJh2fkuMyr<83j=4-#53bGo(Yoh*`d!_GH}o=h#M8o7=1 z-eEKNxjd(^6c>j%)0*xf{lsLZ`^<98q?od~0f$Z8(wo1E;K$NK#Q4 zUmZtfo>@49%P!L|hiudy`yvWSiM?7QWW&_Vb!2;|CWTJcgptf<{}Q2SwtL2Z@OKFZ zMa>e}f%9P|unq63f8A5r6fY{dZ>57r#uc5EJknj0upM8^GEV`2BAzH)ib@sL8;sa={%!cLb%%g~mqR9&#br_W;K@JVT*a{6N9lUIOZfC^viZ+G z4XD2QMJAg6lOH7?$;4T_-uE{-T>Ii3sdh-$y;44!{gK)H=NCZ&suj}p>dw@c*Z` from here on + - Zone: at the time of writing, for each TPU type, availability can be found + in + - v4p: None ☹️ + - v5e: `us-south1-a` or `asia-southeast1-b` + - v5p: `us-east5-a` or `europe-west4-b` + - Note that availability is variable and subject to change - we will refer to this + as `` from here on + - TPU type: based on the GenCast model being run, number of chips N desired and quota limits. + - The "GenCast Memory Requirements" section above details the accelerator recommendation for each model, where, in the TPU type dropdown: + - v5e → `v5litepod-N` + - v5p → `v5p-N` + - e.g. ![tpu types](tpu_types.png) + - The number of chips `N` will correspond to the number of forecasts samples that can be generated in parallel + - Some zones have limits on number of chips that can be requested, e.g. at the time of writing `us-south1-a` limits to `v5litepod-4` + - Some TPU types have restrictions on the minimum number of chips that can be request, e.g. at the time of writing v5p appears to come in a minimum 8 chip configuration: `v5p-8` + - TPU software version: based on the TPU type + - v5e: `v2-tpuv5-litepod` + - v5p: `v2-alpha-tpuv5` +- Enabling "queueing" puts the creation request into a queue until resources become available. If you do not select this then if capacity is not instantaneously available at the time of the creation request, it will result in a `Stockout` error +- After creating the instance, you will find your request in the Console on "Compute Engine>TPUs>Queued Resources" tab + - Its status should change to "provisioning" and eventually "active" + +**NOTE**: if you see a failed job in Queued Resources, you may have to delete it before launching again. This is because it is considered as quota and may cause you to hit the quota ceiling. Fortunately, this will not consume credits. + +Summary: + +| Model | Recommended Accelerator | Cells | Software Version | | +|:----------------------:|:-----------------------:|:---------------------------------:|:----------------:|:-:| +| GenCast 0p25deg (Oper) | v5p-N | us-east5-a or europe-west4-b | v2-alpha-tpuv5 | | +| GenCast 1deg (Mini) | v5litepod-N | us-south1-a or asia-southeast1-b | v2-tpuv5-litepod | | + +**OPTION 1: Via the command line (required if requesting "Spot" TPUs)**: + +- As per https://cloud.google.com/tpu/docs/spot#create-tpu-spot-vms +- If enabling queueing, set fields as above, and set + - `QUEUED_RESOURCE_ID` = `` + - `node_id` = `` + - `accelerator-type` = desired TPU type e.g. `v5litepod-4` + - `runtime-version` = desired TPU software version e.g. `v2-tpuv5-litepod` +- If not enabling queuing, set fields as above, and set + - `TPU_NAME` = `` + - `accelerator-type` = desired TPU type e.g. `v5litepod-4` + - `version` = desired TPU software version e.g. `v2-tpuv5-litepod` +- To request a spot device, append `--spot` to the command +- E.g. + ``` + gcloud compute tpus queued-resources create node-1 --node-id=node-1 --zone=us-south1-a --accelerator-type=v5litepod-4 --runtime-version=v2-tpuv5-litepod --spot + ``` + +**OPTION 2: Via the Cloud Console**: + +- In the Cloud console sidebar, navigate to "Create a VM>Compute Engine" and then to "TPUs" in the menu on the left + - On the first visit this will prompt you to "Enable" Compute Engine API" +- Click "Create TPU" +- Set fields as above + - N.B. The UI presents the option "Management>Preemptibility>Turn on pre-emptibility for this TPU node" + - This is actually a deprecated field and will result in a `Preemptibility is not available for the selected zone and type` error + - To request "Spot" TPUs see the command line method above + - Enable queuing by toggling the provided "Enable queuing" button +- E.g. ![Provisioning a TPU](provision_tpu.png) + + +## Preparing an "active" Cloud VM TPU + +- SSH into the TPU VM and port forward to your personal device + ``` + gcloud compute tpus tpu-vm ssh --zone --project -- -L 8081:localhost:8081 + ``` + - Where `` can be found in the Cloud Console, e.g. ![Cloud project](project.png) +- Transfer relevant files onto the VM (this is required because as per https://github.com/googlecolab/colabtools/issues/2533, mounting Google Drive or Cloud Buckets is not possible in local runtimes), e.g. to run a 30 step rollout of the 1 degree model + ``` + gcloud storage cp gs://dm_graphcast/gencast/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-30.nc . + gcloud storage cp "gs://dm_graphcast/gencast/params/GenCast 1p0deg <2019.npz" . + gcloud storage cp --recursive gs://dm_graphcast/gencast/stats/ . + ``` +- Install + - JAX, to ensure the proper runtime libraries are installed to access the TPUs + ``` + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + ``` + - Jupyter, to launch a local runtime server + ``` + pip install jupyter + ``` + +### (Optional) Set up a bucket to store predictions + +- Now that the zone of the TPU is known, create a bucket, where + - `` is as preferred + - `` is the zone used for the TPU, using the names listed as per https://cloud.google.com/storage/docs/locations#available-locations, + - E,g, `us-south1-a` → `US-SOUTH-1` + ``` + gcloud storage buckets create gs:// --location + ``` +- Create a Cloud TPU service account + ``` + gcloud beta services identity create --service tpu.googleapis.com --project + ``` + - The above command returns a Cloud TPU Service Account with following format (we refer to this as ``): + ``` + service-@cloud-tpu.iam.gserviceaccount.com + ``` +- Grant writing permissions + ``` + gcloud storage buckets add-iam-policy-binding gs:// --member=serviceAccount: --role=roles/storage.objectCreator + ``` +- You can view the bucket in the Console UI by visiting Storage>Buckets in the sidebar + +## Connecting Notebook to a prepared Cloud VM TPU + +- While still SSHd to the TPU VM, start the Jupyter server + ``` + python3 -m notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8081 --NotebookApp.port_retries=0 --no-browser + ``` + - This will generate a http://localhost… URL, e.g. http://localhost:8081/tree?token=bdb95985daa075099fa8103688120cca0a477326df24164f + ![local runtime url](local_runtime_url.png) +- Connect to the local runtime in gencast_demo_cloud_vm.ipynb with this URL + ![local runtime popup 1](local_runtime_popup_1.png) +- Happy forecasting! (Don't forget to delete the TPU VM once done) + ![local runtime popup 2](local_runtime_popup_2.png) + +## (Optional) Transferring data off of the TPU + +- Run the "Save the predictions" cell in the notebook to save the predictions to the TPU filesystem + ``` + # @title (Optional) Save the predictions. + predictions.to_zarr("predictions.zarr") + ``` +- In the command line copy these predictions to the bucket created as per the instructions above + ``` + gcloud storage cp --recursive predictions.zarr gs:/// + ``` + diff --git a/graphcast/docs/local_runtime_popup_1.png b/graphcast/docs/local_runtime_popup_1.png new file mode 100644 index 0000000000000000000000000000000000000000..50a3a7b87c8edce399bfc76a3c5db7dddfe91c1a GIT binary patch literal 92915 zcmeFZX*894_&2Un94V5Jq|y#ak|eVz!$ziLj%br9Q!)>g5JI*P!jb7n2qDQ(j*uiI znF&dz%rnntJL`GzfA@dZZ$0mxz1BH|eee7J4%c;khU*)op{BTtj-8H*ifR{8>7o`D z)waL!9~bRb{A8n7>@2=)yCp}|rlqCr>{su@Kgk~2dd^y=u7@3*>@BQp%?~@fJD4Aq z)zCOiMMWEXB}|2`N|65Ig0y^?;Idv(*yVeBx2FA7vhmDmTtLW=?ev4d<{s&B>2E`C z8pqv=U#$+teQjOZ{m##% zqrb|#_hJO|G3z`Qo5G>9MSE-V&V2qxOdp|svRlLJuJNpA%lbBX&$jrj@~1lDXK7R= z&YxE4AU-4sxp43po7IqQvx{C_Wj;2sQZ8A)@WWYejZ`zmr7tq$vt1ks!!y#$aEf^oFul@`&tB=~pF&=o?@=zO5zW#H9_BG|JhAyK6<_~JpOQ^z3<%x@25p{KSCv9Hz`bUtJwZD5J&w=cm}@Qns9nsXu$}+=0*!J6Q$(HOqU`_?kV` z!qW1ef`ZE#EmD8z4jed8Vzhe;>$RPfNpW*B$aZN=l(Mf_E0Rn8USg>tNT1R+T3T9q zfSvv4&!5vvULx_`O`cu1gp>)oPGEufO}RKL0!Ho7ZqDEh{X%lBPB8 z{3hE}KRVSqvToLP!jTpChwuK}S-BcnN+m;j@@_lT9i9vgUhARcnVFgE*RMxLMoKs; z^4ITXpu9XKIpfj`4YG{{dV?f7_6c87A-@`m5sFn!(~3C7@QhdBcG8f5pB zx=xl(mK*Z3+_-ULczAft$Co?Olr23kFR!2=GTkcpS#Y%(cVyy)WoWC=eFT0#xeC*c$rsydCb2ZQ8W?uK4FX8#buGk8ck>SCM@L4)-2U|8 z)xvR+oEi}k5r58J;M6Jmw#vjq!k6kv`D-f+4D|G0eKMOY#l^+PTqLRl*+Uh$SCYa< zXm%Ji9}6{Qi&CCzRylQrZ(rDev-s6<9rnl_S*7oP3DEM_5lY`T&&vc2rH3HpuNW%Okd@|(bm<~)!BI<=%S9U zZqnNIE-{^~!itKGFUc2JLgcw4_tE)l@TRs}V|R7ionxcN(!eS$xy2J%r>d&@L^b`} zH=VG7R|eT~$^U5c1zj|Ach_gUKhDln*M%*UN@bAS7Fq2(5@WCuEGc(hmInAnL z+-0)8Aj3$2HKf|Z;Pled^OQW`_DxB>_j*kjqEP+Nv~t>ynI)mvheiVD;v2XliOCw+ zLDvM`CVR?aV`EE7N*Zpowv5bTqcxAb>YeN@b+x~7Bct6motH1E5qGD@|H1KTgT*cJ z1imQiYu64l({0~+7mso7Mpb-VTtHwT;#pSf?++h7{QmvCarjKEs$FsCxJ!J)&%1Z; z78VwsKm16Eke-p@>FH@`XjpeMRN*?dq)_ba%33Fz1A<&YK!Dm)(Wb60Rl0lbi9A_? z$B4lSDTtY7W?YQ4xY>?u%ebCsGtTc@DUqcpi8olXvDepNK2M4K@6762ckb^yTA7#- z#V`&lOUp-uac^9D_(6&J>H1R=9J@7{vs$f1;v0sX_@b0knjVe6@Lg_6 z`tAJ2rY+~BfPnOx)cdLbWq1arRwn^OYKMY#zZg_e39&((STa7-`A4Eq<^{ z5L^8A?c27t{7gY#gb#+-;A!kc;_K+`ZU#q~vf;Nvv4~g#F-P6i$f?%^7pH0tsjI3o z)A?WIl6^zX93n52{Qb+9%6wZx!_K-$9+~yUfo046>20I9y1vpm35n98BAwCTzJ-Ca9gcEv+=f3^+QGi#V0R# z_^^e3n4h0dPCk;RMby+TG2(vrN-I-|tj!noW6F(}nMpokP^8Sbx2LBZF*U38xlY!0 z+1lZmDK{@`YwPRRkH*ydSM}S97-s9(h^doPRkxjVw4iI(qw& zqq)$2D+X0fZRw;_QObo@>j>Vw%(XK3@n7!-wJU#GQ)^9pa_lO4OoQFt+)SC@Fckum ziSOFn*RR@K5r<>Q%LZv$OCL|cm`kZ~E-oCQMBv!n zyU`k!Zk3>Rqqq~vO*xH8{&bY`(wfh7Tkh2cl8}D3PN>4(Pz8C-ucLD$V|Obng#SmQ zkGS?{v{)5)x+kh5AuFYRyHv9Kzrc;ey0*3^B_*YPi;+m?=Apbs`3J})7b5E}8r;d_ zw-IDpURgmR9GDXQE0YlLktCF$c1}>xVewD@yLYl#t;$>X>mM~Q?nDKIzq*`>QQ)EB z$y)5TD{dP_0vudM*@RIG{oo*KM6;z>EAvr8XKfCPL3Xt5yn^Oqg_P55^r_(&)3h?P ze{q$Tw}zePvM({}iI%Jpe26&C%gikQ*FDBig>y1>QOf4#=JJ`^F1K&jk@BqzmFZs8 zrSaGQo?BjD_bD@;bSb^WeK}b}&&0$;Pwz=mQn*4&>iM+!8#k@25CB6%k>8_-m)s=F zU%q_#=FOYDya?sA+SA>33+3J`pDYS`Cu0b8wjwBIroAz4g7nw(^w`iH;aj?6UXnLTIZyA? ztjg5`<^^Y^t(O11HD#B15k5VR*$=GTUhgW(c)(%@M_PI;P)E|qiyV*V^%Xpt8 zM3l=`P7`0c3IkdSco2;p8EriRNP>*Kng%T z(3W#m!gpQLiUGl%{lI}KQcE?-n@YisZNd@*8dx8$MK zdvcq(DIN7rL;N|vcYO!!-<}o~ReDbQLTA)-vTWF?>@q}4$llu!Im6h3JvMh1nZ^THi+=9Q~F(?jhoqXrixA z%xHPW`&|2Ziek7B{jGryp@_Gw77s3GejQHoRcmp^a+a)6;T}02j082}9A~bd(UPgn zcT=P@`9ffhNkwA)kW%csq>ZJd;-VsQq9x%=WX>MGsOJR*h=aK40J-FbZAV-8kggBg zA}Et)J-l6A7Z7Hz3h3tPJyb|Z*7zqcugR|X1aok7bTn>*?w$rs*%GCyy|#j#-zG)1fqd~!SnFF>G6NvLov$A!Xmvk+oHky%6DV^^JlfN0mPb?5n=XFLU62VI%4O~ zpBAXft_C(+7dA9BV9t*oJz8QEW_!AC%I%@TQF%=y1~W4=5{ZPO2=VZMq1m{2MX!vw zxblw|btEiEk@!mpRIS#rsyh5E6*ZSW1jW~PUzTG;xP-B+YU?^s**7=OrXRcLOMMWi6Re?J) zqO85kuDHqaRp)r7wlbNJ(K6yvdX_!3arnT&gI5g<_+skWL(@?TW*7HpWoiq=)Q`>u zRP~>{bR>QARckc&`#0kK0CFK}5>$cRrLMWgWpVXGuMBQtp#S`N*=jvLGNNx{qPV#h z?^p_ppFRypFeOYAK{(E_yKQK^FxMAhrfatqDei2@f6-tuv$9rg zmWRr++}kis>*UFkp${KEYQ4^1&&?F1YvO$+s}<#XMoY<3Et!A?1w^Y7BM`xKGgkFm zQ&S$$PJpSGbnoKgVsWpJp{Mkmhg8s4%sqEx*{@dq|LSO@#X^J7i(KmNPn}9uXRkg@ zydj;0P~urtRh4S4m!^f-M{vt*wGOGCE~{T9hF{5l5&X*JSVr~C(%Z&i6phd8rESOr z+^4$q63N`$-}c9~LW7WMGXhgA2E%axi(wBna*{zdO5*%>TLRNxJN#THx?&R(eU^29 zuHAcgO*~9Dp64}PYW{j=mn z`i3Zak|hL-f^xHX3alWKkl(xWb~zfn(@Sr$rg)j-Lmmbp!D#Sm8{R=or&K*1zZO<5 zBtJsnWxc`|H8eEjy)utS^p{K4$*R-Mz3CgBxSxXLzx#A|ch`g;k^zF+gil$zSbZM1 z?d(;G}=~EkW0+@WQ%rT=w zL0Xxw-o9N8FpQ#RikZ(6;t7=8gXd!qdq5zP##!YxNlUCt zQdNSe0xn#*fDEF^7ez`0mfjz#K-NFx_D@6ZmcXihm(sdM8V1g>k&#UKo6+X`=(cAb zZ?+T`iO-&wAKmW1S%Rr&YR5HL#MIAG_60bE2!Elu=b4#S;3-HFJA$7LI+f^T9g7XH z?fO%8B*jRe#_Tj%Td9;vtNCx8tZqA3Y__p6+qFwSuu#j4bK6FNd+?~JGSbMyOPd1? z_B==M@%hwG}pR}1a zzPQPpHL_dTbTl{TBTfRyMs7u{lGl8(Jp^0=%38iCPOX3wL@n`z#Q6A5qgvUY8z_uh zMkL**Q+2Y+JQoDo#eV2Z79Fh0f1#Br9II+k5c@b7MK9T|nA(VoGaTHJn`ELA0a-ko zN1)y5;TV}Qyap9l<_)(WcO5x~r`!;UP{o4ZuzW1M`)FF@4!>k=zL7Z(j1j=HK#Xy& z?~63lr_Do72PIrS4NN^&PHS&(*W<_CJ#T%kkOF*;G4_5u@RjQYfw*($&Ue@T+||Xl z2N>~BEKl3a6+CzsNoeV6(%<0^UXa_P)zTXG1)y*Vtsn9j%`Tpw?6oiMtSjGEM%DZci21N6R-T}8Oh@UYkaq@q~ zd5D*`dzx^<>L zUm>Nb-$7ER^4RELTwL5^<)N{$v65TEz-)C-y@AOT6>E+boj9`QcRZL$e(Gb?i5P*a z9r@&k=_9POLyj{4?a_jQf)giBR4ljVCQgeY0`7L{>5$891NlYSRh;301YKE15Kp+U zUC-P7bpA?`Kis7mX;JrBegTjPv!T7Gu+_^+lE>}zB;0M56!hShGKd08W; z<0r&pGk~s8P2y)b$D)+6XK#^6g9xyNua5>JC>ve9irPfI zs{c>@sS4C3Vqw54+1c6R33fbhXvTy?70w6?JCA+Oo1b?nGgfbL7KsO-L=gwVD5)_< zIZZk3{fuOlAV&>iZlS?VVRpn9O!kEfemYr*$14kyy1dLu-t+qWEJw4AK`{_Fw)o3U zAvwmWrUwNDVIkKXG%Yl^s;B3G0=~F&VCt!qr;CS2$K*ofhnpxVGBSqKVvM;=G%@he z<80-u`kXvOvMP6rbs>aStis^p9#31#qg&n_tr}GRXcgFlAV57m@w4_BIBY zH+aZNj59p^;ltDxtBPLxg7(_UTO~%Qro2~*dM6t#3R0iN47)S=1y}!AU48rR8XG4k zcR>K&1fIdEgjJyMi*Rz8sFMv$%0%r@x#~EW!H*GQlN0@|%>D3yk+wx%bjibGp(MXk~Vl z+{#Ext0oEU5B-3Ww8OS2G<26%;NPv*65`^Z&aC3p2u$y;337(Ze48q0f6+LM55COH z%o#rXHS;n8UHS-kY=lp8E|Ks}W)-^*s7&?JCCp4o_XNWFgJ()~{=|k)O(}}Rqy4DI zYS{a4(f&sxXL=K4fLpR!JG#2wjxWd=%vSuWe`7BNuOi;TAN;1ka)^ZhG6BHE_%haI|*Nsob z5Uf%*$->F&><(kwy7)x`zX1B3M9HP1;o z=IXSxG{9C@hW5O9|AO-J^=ej=*@cBj(az>(AxTL|Y3UrTOymb&@B#q=J1Rg_=$c%8 z?I6oYogelbH&D=7sC2RARaL*}?cZMkUEC9a(b@LzkA3pQp~MK0V0WP3C&p7JPhKte zT9W5y*UV~lDZQGhEkut`^O^bCfJLa-OHyIH4?yYEp!ipYPWdiR7XQNcPP5!kK(=t&0DY0uu z6*9*|g@L{Qp2}z;vRp@9)8Ql~%P1nllloWH1rY_G*GLn2d77JwhgcYPGASew0j6;n z8(f(E5`l@3ph9+@=qdpX7edjyiuvtIokyBKeR{|n-0vXD86G(+cIM39km?sYSt#19 z3sXDo^7IDX6JKOy&CP=v==}b@PgEeX&Ng%5SY6NLLJ|hMqtFH^;^|%xZlvTMu#+K+ z^1pokTYiB0AmWGWW3W@et|OGp%Wc*2Z*OTqh4b5jDdy?Rb;N008FjglfW3`P3U9D^ zyMZrNf?8Wo4{DJ*<;t%J zITezPAl=0~_a4}cQ-DNsOJ0Y(XPdAmIy*Y(cU0|FzazycW@9VJrWakP#~**nqe;aH zs1%b9Vhbq@)MWe1(68Y?ox}o6Ygu+s@O2@XS!?HO?#PjUgZ6^M0H?+*AN1@=9SOxB zGLofms{X!`ok@-7&YT%oSaF?-uwzYWT6}zYvC+wn%zVLq#gynv=CQd^^;wVRpe>!Zl(V!LjDE$q+17|>ki z=0jk2P<0j;7gtmyX+Hol(jk0OUK7$-g9X`4UK9AL*K6jIAe(S}0}zF?C`TcC12mk) zwRP4!(Jg&eo7`keN0*AJr+i8ifF|;wYmd8PMoX(&Z(C?oRMg1GNLW~yT(Y2y4A>p8 z%8_-1`g{PitNQxx3*D~z(b*?=a4AV_pSE4}D!(=Sx#9C?ivs-;BOhon(9ycP#~nnm z%mxPuBirWgVba0swNVd<#$YHlV)zrMUXBk z0&}UADd{o>!gi&{ZvIrXk$~&)=Ys*Jisr;e|Jw`jaVWoc5~3RE%$4A0`Rx$-)~CXK zRVi;>&Dn3%T==T!!oOJh%gb*dfKZrx6lE1(Yytp(P+Jq2I(-JQIp4m0J09piOij_` zCEGp@O=C8AlDyAv$$8NJQbv6YcRL3>YCvsRl;eTqFJ^XSTDN$5XB;Q9Qq zZ4!DB(K}0pMD7glr`-NqHmt|q$qJt}aecnodm=90{Pzw6rf0P-Uyb~R3o_iNM9)+7 zTH*0Z1S0`mXxQWZW3MsOQc_oWnO_-3seq;FjUahfPv$>WLOCry|D2wf=cW3+i!!7T z7fgO$-ng;cV;8>F$02~%+RM#SHU0beUHC;F^TKArEWgYAR;d*A;@$V|+wR=tfd*JA zwS($TcZSkYQssokKg^5#JH9KRG$#-{#!WF@tAFy!Q6=IyD{a*c>NDd zETNo~3m6LmA!O@-i3X9Eoj4gtS&~$L?tvFUiAdq@_Bc!iGEtl}|6^;--=eW{@2oJ+ zRGr%-VzJuq&RxMi%`gQK2QiwCE$u7>{D;uEVL)*zff_07#06? zI-m5dVlU0G=x0k|_$ETCQJf|Z2Q$TKnFXjRyO(Dd!wH^f>8QP?F0IAsN5 zP!*!U)#DdI;N9zYzlOs5yp(#XHhgy*4NA@ii%2hQx#r$jv4OW(y=O?<6%0L<#WN)g-U~)Bh zL4CP`YeyU`%?_mv($%qb9wyuH`{(h{;g z1>}C`W}ky7#AYO%hV*pqXmj_(VRz$XjhFlaJR5KKreAuv&c@|B4h=D8NpYO z&B&-L;#Tj=lX_}dt#7CuJI!JtD(dlrss!6GIXO8jnqwrelcGp!Wv24-MJTyRat4@w zsHvG*T6O83+?)(eU6O}07FeEyyOfPABk$BM92HAF&mTnts3z!;#7()}2-ti+cM*))Hqj3kz zk*x9Bp?~YviJ2fQoGX^j4Lh!v@&@rA-FvKmBt!eQcsB^I9-oz)87*})USOCAOx0!) z1ETQ_=~|hnHlRR2;1Xnm@KY)}gbaO7L^L!1_xbAO*5)t5H}mu;_Z;7#ldXdb z`A&tH{5vqEoF@1CE4W0@Utf-F3YzMPlI|nTY2a4JJuZgZI6FK0_!tYsKoT+%Sa^Mo zmHG(B?u{y$)ou0p)n=y?)Q}4g2gSd3bnp?IF*W`> zS-~)AWipotM=1jz#y7b1Rc^t`EbhF*cLi`uGRbrkc+?L;7=c{u?JdQz+kX9TxQi&{ zJ9rhyZM@7*q8xzDB%z>ZY(W>Xkw8C>FT;CMS{gs_T7#E4=py|NznTFIEPs7^%Lr5u zkOr_GY+Oe3vt1QKmLmUATEWs(`ILR|GiJK#%F4>D-13fZ-=Z5Vw#)u!y6}#^PT{89oHiXu(6apA zF$P`41Ov-lRw-tdhqnf>mg1nx%~eln`Zebf-%wOuED=sZ5X*bKOB>{5|O}RmCkBN>JxnUVL02c`0NzH(jjm^<5=NM`EPo6}@ znvB7q&q}=07t%%uzdI!SweiTTW2t*xPA$fz<{XIH5AR4CHi(=Yt z_v3qq`2_^jE?ol818TizXQ4_RYwE;=U1_(Pu?4fde2V6@mXU!0JOE@u`}jP>+n$~t zJ${yuYRfn~dwU*T5IS%^d3cCNDetEbkn)^|?E8Hjw~y5e05sWEQt70CJ|QOEZrp@vVN z;HsqY;aTm+s)HgA_Kcm44ZwLYB$PlL7#N698jRhw4fSnbT^@gyR%WePJ|tU&ij|mp z1QOuS7rD8k&Tp&>;h?EwX9jL;Yg=f8Lk1s0PPH!N3pTraxdlm$E26eFOIK!PtN_|h zwvM@wzPfrNg65cuXQno`Ie)zt^D!b*jU(JvZ~p7WXOS;g z@PH_=lUAl>oX@Nwk*rObE3=y)U#=?&9lJiV=RkNWTQHfR{ESy8v!(cTyc(U{apsbS znKT}Wd>H*ojF{!kg<|#iU3^Wm2t?@lU=FjhPl6mduJ_(jm{ZDQ_J+(#EHs@{^3mo% z>n0u}X-!Q{09@T~ZmWR%5we9-X3&WGQcK~8X|y>5r{wqc_WeN@Yvs-*4ig`8=Waks zl!4QM;)XisSYpJZ_EZA{OW5xBE-Gr@rqgi;4Gr#ES2w19=!a++^TEPbujuLN!GKIj z%A7upSgF)h{pXJZm|foBUStjKNca5V?A^RkAK|G$xq#bzEZhL00Gumf%B0jKTZ30J zel{GX%b!6AD~gJNa>;-5LS0+(Dqop-beI-Ir0;^OZ;c=ftjAeG?j3%~?>Sf-A+u5- z(=y^Ld}(EE9U#<5pbb0y%a>~f?P^KVKqomyCHZgO7)U(cMBL2s@E0OD^F_h&_|>`) zWR3@AyKed#WIzgol?KmkU|`1f!J+_rE-oFSIOr$qEyZAUP|y%YL}FE8qtxZwm)mAj z+=*nKc{!uD-IOfX<35}aW8wo=U77qUKc6Y+;-Hhet81}Bij}psxYwcsKMTc!DiS|F zHb!WWtvllYYS2NHgO=K5(LO(_ZIohHF>xQ-H=W(LXI}`#Vioy;RpMMx_Z3GDrg)R# zgEW>wu)A_)`dNg2jF&F8Pb4u<)hG)@X*eYKIvgD5RXT>pC3rsP@$U z05PbCv1e|;jmR4eVY0*spxeF7INBV`lPv%;tU%QIIrqq4v2`RbY54JW2Ugh?7x2S- zsZ#C@=jZeX_?AE<@Yln)g0SX2gHbq81Zq-6WoP3I*!>ab4(mdAV?b1g*Ua>IPM(Cz zG%*(1B1AnC&+$Sq&1B>KXUodU`0GEC@=uj7{%);*2T{bEAIg(lMsL zw76E+S0;;Jy%OaJ!%eOJ)~$%!bUA_A$ku`80kKA1X6=KSqX)osAIOa@eOSL z;dm%qf#+Vtl=j5bvT=m|t>9-MO}-tuxY6gA#zV;rDNk zwYfGFysg%AZj#8N43jLf#8_t^-`ZYE->rqBu zMTo3J1>j<4_DrOyT9E$uHcLA@?h}pHN$>6*#vH-851Rear;mc4Fh#`Lc{j#?oxP-vLndLqDu|<;^Zb97dg$vu{HDRWPNh+&wD|~RMJ8{#6yG#- zTZ5^DqdDq?Zmq5ab{erc+-DZ1A(Z9b(!H>f$OoIZTHluBlyUrV-yr+jh_gV9jg{3% zKLa7Yfjq1F`zMq>J$?j4Tagn{%K1{01M6bH$!ZxkC@@B_zsNs(L7N-t>mT9CgcU#Yxwq_xbxV(3*J&MX=E}-)F)8*q&iPe9KL(zfpUE9e(KQT=Qbzy zOVEiI*D7;enj#bIlr0}#)J)ZpCG*Cy?EgkO@bw@5`PoEgi&&=mtaqJ5%EpX`3I$*R zZ;&d$uJ3_v-!SuwOhB#cyRqQQu)~j3madjD;9joH2W&H47Q!f>r6vk;FQE&vxX)slCoeF{qgTYSloTpe1 z5})!0rw+5ms(Q~JV5i>^ssL00jlbCW!KWc7*bRo=ffLV0wX6s`@ z(BQUc?9M`+thJsLsb$5nwUfW#kulEYkwLNC;HtQ>XDRlp+9UVgmB#K!@<&CE4>^a2 z^N%BGcvxE2UK6w}7v>B{#kaD$Gz0I(+&3{~37mrG9u*Nz4MSn!EK+A^afmGi&3AQQl4fSUsk{x-b!To1ff(~=tI}3{gVxz72EhwxQ3jqNEyV#}O)LUWnBym-aIRB?+Ps=17d|o5JIT)KN{G;d6 z1?`feqHc4)TThmHKX3gF1X;0KUx7Lo<>2D4lyC?b03JG795oO%oQ!i7pkcOJGkL+B zQZ>3a5gZ((Klmg-#ej1_N-^XvQ;JF^5bcU3VpTx~z>3n4>5pPlJV7l}dsR*N7YuSB-6 z1MYyRvM+kYcGP5hT~n|jK0sjyteILyms@%k^Q(AXlE@{Q36Dzg{KYQNF zN6cXl{a`C%pJPU1qc~Get^YW+jwrBhzy8y{HbATKj1vp`J09|fOZcv@G7!4$(jRnx zrS{3)+50HyvC-fA@8$$M%Q98>-8LOskaHUh3Re~3JbIG`*ahs+25&K$(jTM4x39Tz zbY0Fce|%cJXNtu ze##E|^w;Dh&$yeSf&%DRfVAfJcB?`Il!}yw8RNHZrJZ(8_?<*RNkO6`xPKxcd7!;(&t0b9v}H zTf;6WOprxa6AbR?=Bij<9tA5X)i4ln6LO${NxAB`LQ~p%lpgqJB#C!k$(w9)UOG# zu}AiDV(A*^x`I~S;nxqhwRi%0kN&!M4ZO$>MkH`h0Yerlh147wcwfe{O69Oh?bo(2zM?Q$n#AQ)np%LaNyUV7NuNu_zE@ zq5{>Fy8gJ4TQq~4l8@`GXq1z%bAaPgVDkE z_8R(68Y%gRbgCC4E+X=+r>6(F9e3xmgm}V%2ySTQ;P4Wt(MUjDP0b8)E0#ge_wVQt zpxD9TOb47x;U77=a-{dI`x%bi&?dmW+Rjh$y*;#NlgHmfZF>9m9F{A*nWd$j)`cLd zqsOe#>4dh4h7w3y)gi2Sz`AVM1Yw>)Q{njdO+Z6CJNU1nl;yKpn}rh^EKnxjNb)X% zhqr8-+hX9y1)sr3#&0Rc#-P1C;2bI|&%>r75>IKaadzN~nqFB$fi&Ay2ICjn3r0s# zPJ>7(-0lJcUrG~N{(xZ~b%3>l)nN(DyiFmRCIy5VP+}!U-cg~nM)~;cC@}nGt+0RH zKya6`(7`JPZP^JKC+ad}(3G{hHtB0G6bp_I)B?J%;DLn#U}c2{%n|XinWfr*-OTa; z(yv|#pFV9W6pNw;Q3De|ydu!PcrLQ;1zIi5@+PG+*B-6^Rk&kk_~LoaVumV)=QdAx zgMswo7rNb;$pirYn>{IxJd_gN)(-hSzRza1bV(dI!DTO5jVFL{I8 zZRckiWzZ_<1@E9DwhO2@ykM6?N3b?ue8c?PRz9ke%zG?k#s0S!fK;+w;7rD&i|#IJ zqZ_4C=DAL|vn7p@#^YrMcw=Q`l&fYw{LWS{a zh^L~UwTojWAJjg~ zh@6LL7olX!iC())mo5!{+~IeGjUJ&z#_#5@q+9A6W}Pn8=`urC>~dbF8*NP9l;pqE zLMmW_;a4bDsb&Zr;c&+Xi7dIR9E1t^$mC9#$y-MBqIU&WA^FCtelRs}v^3DY=-rszrDtHk)hddAm6gGpouQR!U>J4Ni~UgTxC`oNt<1rhC6_mk zM*L^`0}pXwwQGDWvIy(gu?xKcpeB{QGw?N28=)#zb;3@HSO1@!wow?2^7OnG$$K`d zs{H{(0uw`QDHv#k4wwV4dQF+Ga_s1OJ|viHW$}E!{GXWwQRz!9P^`!FqSZ1$NTs}` zN85Df%#rHSCmh}MTNaK;ul~6}QS5Ec$TGrOeLedr#y&nn$u;6IErhx~=Ooa=7M{TXd<$mQEg_r^UqZ!o+9aruA(?-F{y#ls;(!!Ie1Ch5&L3UWt=3n_sok9QGn>{= zs=n=!N(2{A+Qze#N2B`xbIOIYCc2;RB)87_c8{MvU&d~56(TFuS22o!Oml&53zeH} z^2N;qHhzC}{oH4Eot|p(?jBc9T>Yigy6>6S&D^DD4^!Q_I<=ucwrPA|Z-b~$OEsrp zk*CLaV+)lI?{8h2*dAx6Vw$Zo{1r~_+}Z~9b`0D8POK#f#j0hbi0mVP(?})?(TjKT zv`6l9_RjC)ClKHFYCJ5VEu;No`0Y>Sb+lqjzs4`W+`4n?0&STvv;4T$KYvE=bUrEf zCJxBqH*tGZPmRy@h+1g7_$qMbHS^sPmLvBWJ$(M;^2vElw0S$^&2Ij{T4Md- zk0kB;^Ne<_A&i@Er`ut4pHc3YSmklU`}5&6wymYQ7i4i={H(p7UV!p9jMjW?4tpry zX{Mm<^IXMR!Mm@M!8GdxBXwZ)F)$x zzAZ{*w2W{^20cSX2&?i%&8|_|WJR}?@#1ZF&r2s&phAJ2F-@y|bdS^_pNb8%s#_O& zpoQ7m8oq?Fxd$#60KSu<<^U{4*45^}Z97VN*7-_ly4RacqHus9bPBVB^OD!ZA0;*B z=H?=ZdncYuJSeY8EnAzvT5BXnCJ@Qt?tki=T3gSGin4{*z$}(=h#lj&r+NWq-AIEleF&zXyR45Q<{`CDXo0N zGoXvpLT5}x!>;Q`r(1;$9JVZ1rMR%gvs(G;og_IMhnyNmKMC@{^hW6~mC0?9KDt>5 z!y~B82aZUCx3Vo;XLNG)^W^vOdE6rM0cE1U*k^d}aUhaYd4JoOtaMaX-ht8B9N7aP zC{4?%pxrV~+M(}`K{kuLBZ-s>)IaBDLR;LqR_wb{Oy~cjpiz3Iz0oJSj^rY?SlDI^ z+=X7Tx0o(CTj2#Y{wBC^JhBc42-Lx4eV)16cO53YA0emaJuaS#t{+;iuUJ)q`ve^n zLm#6)%)qT5bP?(&(s6h1=3?F3$-1Sj-{I}lhXb&;0;#}zIS?)b6&8(lls0iFjBpQp zZ*F$NQ8O5Sm}u1PKFXLd%Y*v!mwT5#cJ#*i{Q>==yLSh?59oCgY|qu^+b;WEX6D7I z-SV1$cJztbt;}?Hbx~T8f!^T=#?b*)dno(C)6#_6sLa?Gq-aXhYf#pJ>_S3Fyr!o6 zYaH|WZ*f2!8<=u~y96~+l=7eRn??bYXh=qFqyBP7B9u_z-8^+2?G-Mi3F#vZ4VR+L zO_$hf96=Dn5e|=mqbDp7D{DTWnIWv@wm~GXBbn*zP#O)+$jF#vY8NHN5{PJNa26kT zDFvE0j~xcP*9vHnVsXcm4H%@L9UZb-nJ^~>eSd@fGBiBFOh+a(31>H1qO%C;CR00~A1+>qLC)f9 zqB|=<0kzvg<_1SKAgYO@Cw_o%7K-gEvwr4yiS{1t4nHueS*=>kXgq;q0$sl7OPfDx z6$j-2CpehpRh~sj1S1IY{XcS=nc9#OhG#HG!OyG<*&An2Bm2wAGDBm9OA2JBm-Hd} zzp+7ZjLr;&pf|`?V5FAU+@#ZCjzZH)ZW}}yNa&KBy{CObs<+1EObNv{{8W^-!O|6p zM^7kLzBV7TyjTLj5`-x9xiYEy(dKRyg~i1v@`9hy@3?^2kga2g-sm(f4GnZWVmH2JVV2+S zCkqxqAO;3b%vZNvp!!yaJ!W}qms;R$^q|@SZtmYdvMKM731SJKVL3sEgW$sk zi&4<>zkbnDZ^fImGEsd3s6wtraMH@8-dpycgpc6q;(~YzMzHABtAAR5o8^g0Nuhaj z96JckfcpIUp*QHFXtLxuco54OyjEJ8e3};8p)7?zI(C64m|l{68ciUER5Q_1qtgYD z8v&fZ9zDKj?}4LVkjjk?J|37)pqXP>$6a86Qxb$;1aUaDW&j2N@Ed~h4U7)k?%tu8 zxUIofgdU=SVGA(Flv4LR<3-mGJRv%xX)2KZm@c*8p~}jnnCWmH1P-I3^doW}I+PS2 zAEJ;FD!g$Xm2O(vF=o2kuC6#e1SKx59DD*`0iHfJGBIJLHeDFG6*kbfcRRIXv^maS z0;z^W9pF%54>#BUHar8(s;Iaa!dikF&2cZFR%lo(A!qOXy(w%1ItT7Pc<*)~3hW4_ zU(RnRcBV*T*Z|rZ2If2p4Nl53!WX)GH^vm(9xp7PumM$WwBISJt1}pNA~=t z6ZP^=3b+RZe zCmAiFVPQXKmc|=xfaxnLWMWm(){Gv#7Gdu{edvdMXDb2$=R&du4y*wYYz~p;lBy~_ z1B1Pt9Uk$Jq#LjgGKgCV4Gj&b1&EtMXV0SU#*s6CwP=;9tyL0;NwM;9pu7Z=8pA_F zZ#z3zUPqOUo+SX(v9TAMJ(t%!Jxg1y(QF;x0Ame0QqZMBB=;?AFypWnN*~mlH}~HC zt(7c{|kPe4Hj-lm)|Gv`r;WJKq`H4MNA)NQmmg7<(HRs@bzfcl=M zRkGG+mS-~q;w@jaw*Mz{=*}Gl@$w_p#Pc*=~nKvbd8+Kx1eWhZ<^GPhI_goMZV4hy+oZz@F$cE5dw9yp^1_B-QchBE3 z*Ecit-G=zoWQo%=#KO=B1H)1%H~pg^zjuGP+oJ1|Kx`J?b-93A7M7MT!rgAU2K+2A z?#_8^OdZ*1mGbCQlEm{jYndJLf#s?k-uuHeh8Hlsm3r_|cr5%5 zj1M|a-@m_yb;QQ@Ykd4{d;=Vx$7p|G0{@Nih1_qZFA|}I*yw6wQ*SGRQ}q6|6nME zQScMm0C9ZP^wKKsKP)Q`U2R~xOmf>QAdFze@Ylm`j9rUUfM9xxOFw1w+5x9M-8+mX zLY%t>N2jwmx*;_n_v1d(v~UQ)fcyHo*)}o({r{?&boT_$ojZ5-EQmUcqk4jPTFVGN zNxbq-O~)9W9xHYhf8>=0*CLp zOD1Cf?d4H>@N97;BS}q7EgE*(WE`LXX32Sd)g8MOMI6o;2CdTkRxB6Rj%Zw4daHAv6%Sdtzp4SGpRsJvL z-aMSkb^ZTtkR%~VmV`=WS&B+hiHM@rppqyVib4_%l1x#QmIkd-DbhU2l#q~AR!c1; zQ8JWTluXGG(eHKJ`}_On`Rh53=l491V;_63quuhkKlgpz*L9xn={($IH0_F-E6!F= zPk3fLgOqb>bEtJnZLDf5EQ`!Fh`=Juob+)z}g#->B3*MG#)3zbc^46mlVQmpvhwCjiwh76q zigexQ&*RLt`TG8#C*mt}ke*&VO=6(6uf2UW#Xk7fkn5uakiHy3{znHT&;y@7ez?PW`>ev=+4>ewr-8Da35vj z>EqMIJI?Y{%mY|QLsDwHXISHCy=gUnC*7xy!bf1L`mog(6Zzs4{VY-lbacGB+qYaB zR7H>T_v)+{m6ehs%lHL6clI@iV(_!t??f{oG>k9ZRV8&9HE)*hZnnbh+mB!Qa&86B4vQ#b?YIG+_emCtCl$ z?iKl3Q_RF1+HwE<2EPaU4K9lq>V4JJ4%|P(Ha_yoC!+R9Q$(&1q(T`E&+C+a_3Ge% zV^CF$e_B$HQp%q%#fuxHVy|B6p4#oJAHE=Yd>Z;#&SR1of>c&q|C4tYn<=ia`Zj`L z9ONm|r=8}}YsigOR_3(nqx?Zy1+<;JX0LT8HFZ<6H5#00r@f=IecC>A>OcC%lgOx$ z{N3w(k*1!a|2-8k`mVikTb!)rM97iy$&_@lk{0o%A_vIan4FHy*v-x6zW1@(-CW%} z+$B-3-CotVp|&_16}6|j%g6d8eSpJmxo^w%FFNQrZOG~&i;TSjc6vw0oIOiafu!TP zzPw7jB=!voxmp112zFD}7)m26+gfpNm4@7qAS-L@SeG|~UhfronO+$V)8s`(H&+lh zXzXice$8C%LXn!GG-vRdI30Z{^o37KhKdZZjbAPM7)!^!d*tL)$5DAo*1xb%ImzZy zFcudhQuiG@WR(czPz!@Db9Eb=*j?P`CkSYNy(xT6`Bkwr{QK`8vsMmyIMFH`RyL4q zladS_FLRKX6W0{+t^JnELbcRrC~HLK6Dh$NK;H|$?e4Cm=Uv#;fm4=+HEY%{lA-V3 zbNnN956 z;0BozIrPmoI`mWee>z#mnm>4^rZ}V`!+7~V{U2H*7A&_&Sp#ZbJ}H*NB|Rx=ze6x&aEVDO)Xy+XGuGsV;}g-4k{*R10KlQ>lXI6 zhp(?6e*tR>M3OVQYMdFsfYGxV zrxz0k8(%u#^v%?zX(yBq;pg(6l3(ynKWWAFXF_JJ`@-@oiznu`>X|GrT>IH!63*iI z`(GEK8n?W;U-V`_{iS-{Dno(-U)dkk-SyJ`W&O@;#er`Ls2#e0fc7|DpKsStC3sZN zTe}u6loig2?CYy1bMQ+q4uIfuB`i;<)bSm~+QasaUNI-n(_{4B+V4FHli&W{`XF*= z-lW0KqZdkxm_-y8Txyxzc?V!@OMc+%aj99M=@m$@Jrw4Q-AXE|Rz*bU?rE=_Bmq zHb{#k!kd1tjafKuI4Dd<_ZZ3064#L?)D*N*&VS$SJ`yZC#=K_F&$R>Yjhfj@Y=Bbj zO3UM~dv@)0kWZUfID1!j_<;~p)u6T^0z`U*+TgpdQdh%v#lbWGOs$$cK`co1X>@cn zsgfZM`lMX;aZL`5!y02}9uB>+#56{)V5{bI@59xvW<67e`N6$h5kNR!(5s#kiw4b} zXEJZ*eA!ls~0%^x-1^S!CW+n|XMCoE)ozVk`MfSnNzm!jPJIo0*SHoOG3Cy{N zgh*o}@>k%Hg5RnE2*E7oXg+u{EWBV6!UTIQ5=_lj#959IBCzOXcFgv}?A#~IkF7a{ zw*7FUwA`UVJ`?p8HeblS-`hI3HGg@WW$exn z!Qr1)^wT{98`GADz%C8Sdz3xEmj3;B;+KTaE#91IfQ-dra}(BDQQ1^Akx7*hHE_oU z1i#W^97CTzx#(KOn%C{X6k@X6V)nU^Hpc?q9h_{T=Pq#JZS|$8yJl{X7yUe=zHN4X z($Q<@6J`)H8FUMON>dZwMT$L&#ZZqz-KXP!UpHD|Vb}$+hO;!6ehk3jgt?V<=dH%SF^Y0_7vOu_?Hf9UM-2Kaen}eer{>*k}ZFwg^!u4A@}>j zg@bdoI2kN&PL-f!qu(@*u(7qxLY%2Bxp~VLKRRH{2iQ?A_B7=FtKu+7Np(-%<&v_i z18sjbls>O8LvVIXp<>jMS4fPypcdkwbhWe7!O=QorHi%#o4MUSvz|?}w&V_MlF|cl zliauo;WN>SK5ecSPiA`hRZq(*bWmaYIMS3pDDo=%1h%s~KP2`W5kW{Iw8!s%UQo-l z{h1Q)H(6HPa8v<1JBg%IBFRfpH7kt0+cH?lHwd3uX)nJqa)C$zm>$L<#8Sm~K#+xs zrtJ9%hf{cMdgv!xP+1I~ekI|9X`H1HTjZae?HPfv9m!+X#f!PFDGwhhsg5nH>aTO8 zY>i&@SfdF_J=UsXds z0PTcN)`>0xeN>g%`wU|ZKZETyS>Gjt4+o%Nc5^8U%D310_U#*yI_(HoK+eox^z(bA z@KL$BnPa}}#I8LaC?uXb%Cu((`ki`L|DQBy0V4PB`sM=SoS5j@R5>?DrTPZ9kP%R>!6wCz|6Xl0{c{kS8Ie%$vxl}!Zl*jOis+vm?p&E}%(wN`$PDne> z;vDeD?a>k6`#E0l0halSKuec|4S{HRX zb|bz#x0LYY0(&jOoR3dhM*5KBv|sa+mR=2BdWQoY*p3sRpa9vw1TekPShEPM4r~{n zcAzP@yh#-gz-d@KVYw`e;_-ScWP8Z!0*!aCyw8eP)}0Ck5%E;@+ZFEhtqONLaxHCb z4;zmw%zTHvfkm^}VP>JMnMp?&Y$-hRx=5%CSF?S5uhC&c&5Li?FNAQ|VWPi`cN9wpg^}36+*! zLGHhjh5ZgY$Rk>4{T@Bx^?%ex(a}o$E6vr0r>64`u>x5x*9$n^>C%ZL`r>FOq`lpKCD6j z{|IzYHPgs9Iyd3VeGgI@JBK2GBES=lf!hX9$%_5Bjh3M%%_o!M4eP_*Oh2pK!=qm} z7d#oY%i1-$$JF{HJ5^_=Zk{I?LqhjJjh_Hzn1$9Sl8;n(7JD2f2ryM|JzCJl`uY(f zXfa40Kaxi{@1~6#RXo}IfqmLKI&wo+agvkrdv2+LRyB6P_*=I`f(`rZlq!o6GVueU zY_29UY%m(ZXTkiN&;@i9>?4Eqmn@gXT#GgWbL{EGKv@cl4@{FFW99rTs%`1tnySRM zMCN-3L>*#tJ9a=V49b9D67!4O020(xHdOa3^6_@Vu>qsekYnAz_jpY!WTHf2`AExXY%Ankwf?~A zN!F{BqAndXyG=o>M*c(p^fsS!4yPtHY#Pj=K_(}_o=>KKc2$Gd{08wQ#qfZ?ZuD@F zK&s(;48)y{50aCSDQ!j(UV!(w7Ih zyj^R?jK{6$JAhauZn*{T1)iXAz=NQ9u~OU>oPcmm^LupT;DBZ>y|AIATSBmQ-x$Mg}XQT%Y2D{* z@5}Kj?p%6#1<(lJA9IWj>1Ta*@6D-OC!4(pRrWtsPz!p7;7(y! zzI|y3&J`Um$G~EyM}Pzx$Cu^?nW$Yxj!2qacE_IV+BTm)=X=ORCmM0O=LEekt-6B7sd>0l zM=?wS)s*h!z(-Qq3e85yHISw?>j?;}*REaDB`%h@#wNOtJJ!jSHLb&ov!Vn&WQxH<27*T5`3 z(O)KFC@&%ET!IcYw!6RYxf)%o4!+5DkpFwiJFPQUVqXpDdWY(W&E>$u@Mo0pG)Q=DIMDlyG2#WSQh!|P zV}y~;C0fafn-KqOXGi3CLdQOeQ8n3w(lwd3TWxZmY`Y!0S9S1)^DZPnpn;;t^i<5% zGB++NGu|?o)5lz`dc%S^%is;?r)p_*Q-xCUOc`nxarCKuu(r_K+EHq5xEr4uoGTx) zcU=Gyv1zBl3#gg1CrM7v;68uafe6x6m8SjA&%j)xhw~<%Te_Mqb^m}r7jBJd12A{i zes~k6a6_W};??teJp1TnDVqy}l}+uNq&mMP)%4@D4n4!73z}Ek>Tx9Kii_2ILif(E zi~?*D8GB;UmY8(R=sfPF7?WauDBL^nKy)V-=W)Z;%9u={y@UlrUpOUl^wp6jo(mrz zDstX;WL|jj!%_|?2o!5eSrg}}{p~`vD!iXoAAIn8d%~_kM%feWS9+Gs7_%t-?$1$U zzbWP@>s$TWn_!pNf>~c4C&pIl%e+!Bhqf`8**x(nX#p}Tl8$JI2B^k~CKgB)&`>Gt;efmi)DArz5EMka4 z%Pc}8zdSqN^Q@49a)U*9a~0D=oKIy>7MkQR6){8Ub{`GyA6-L&^1vx(5tKacsVil{ zN~y%XKHQ)`?WhW{!!U_M6o$6~*5PJs?@3BV6cP@nfD@nf37Fp0wl9C+WPKM^GQkm) zZh|R!iLke5oLEE_GEmi)EyE7uWFw>&v&G_x`;8J)bKNI?i1d1PCT8}|8`;^}7V3VO z*~u)i(rA3Al9L+}COq9cb_2LFSZv+hUmXJivaz<8xL#00{En8=?5npPZI|>u<8fof z29OAbiQc4G`|a$71^cU-sC+c?DP~|}sSi8vmDV>jH4(*t?Qh?UW0@uqmxQqrj1@r% z3!{{Zy`b&O`s!tn!{t!sV z?l04=Px9*PTfa!%u&?it4(lFQEocC5vzHT=K$V$7t%mHtOvX@JuOI+Bton)1ME;?s zr{*S8L&-GaQh5X63EOoQ6$SSp#TudpVxZ@SA|c7BP!ivd@82jEd?2O(Co#~t&cJA^Tsr>`Z?0?ezS>40n3x+MBPCM(^;qzO^tVd*xXuXyM&jm2x-2nTuQ zwRChN$yspmU0+NknI3Y78n#(e3*WcZ%uhOI`iS`2c0k%`dGdvJVojJh~I< zgP*>#$#OjOE#7AxdAxWGaSjJmy8EBUj1lxV)lB zdyIsrS%jO5OZuJ;cS5b}>d5M5CX9jytcpvFcOzT!Fuw5h6* z^uA&QB46l0ByyQCLc#$oKy+Kh^*7^r_2f1?+ zT~t6#_)IPh?=mVi^%KWT2_#1}{BXYW8E7L?461b{ZS&@-6ux*7AtdOe!ThA#u;@>I z>@$ zkGmqi^f$AJBImbRH*Te*&_AD9JW(tFc1Y>SXo--~GuTdX?>lmJ?PG&pdYq?K-|c|7;j%;e;e*VV>091!uq_;1_uwAGaAk zU?<_#h1q~X0fZgy0US!%s3waEfw<1gE7*WAY#A6B%$=Kg=gzl;Sw#SKcxp0jNvIGxc0gES2F#B8wySqcx zjQxNxXbW{l7rbr6o{2n8)0opkY6y_n}4 zhXgfu_wWd0AHxTdvO0At!a|5IE%|-zLn(cgj%?OG18qX7B+oo0fCxIY|5&y!7Ln}u zd?RX`4<8(L6)T5p4+$c)6EZeFtLw(nt1Gc=E#byv%jY{r*Md71Ig{Jo7BkX+p{ggzw?Qw7en`FRNO9d=8M<1>O^gUSi;A&JQ@qFqGTRkkG?7#r2&tY zhjRDsndkpPX!e*V?`6x2Z9pW_m&|V#;YbX!s!UW=)Evp9PKZu-w&8BW%>OI$%r$Q035I5pN%{e;jdt=5nOM^`@q=|2~y3B68aK7EzyQ(geT zpF@=j1vDKfh_c{ixuZakvq-bx0g_Jkc>QY*sFq6#ijoV_{5{`Di%BBM%~jF0f;JI| z)KGcINvBkUNAAdvD!~E{BK(7YVE<|HRyAXgPwj3)X?yWkH7$gGph11=cz%;fjb7oI zesbRpO9ejLLY-`Pmot_OkIAUO)2#8EnT!Pd2pgMN3YPq<#ny4JgtNAUU0fmDW4A*W zuBer=eAA`Tdf(yr3|r=B+HSJekrwOp)d$a?5;IN)DmWk zFga>x<68onSHTT_1o(H`ycyX)_lwR4hsAV3Vtgn5i{AqtARvc`A;L z%_3w)9J+lcEq+;%zJ1gB^$%-X;r6eQH|21OK9w|s*RQRERQMQ(S}Hnz{*zEa@F?D? zVac!G60!PxdttZbeThsnT*w!W5W0k03&N_>(|3Q=#J3KDJkhA)MDW>GlOqTxWPqd?Hk3P zz==5t*@R9nmNm)Q@NT-)l453(OpO~FtQncl8Hc=v%`rK-_1vz?y)&1A4g_wVb{Z&8 z$U>#80%p?J*PlAIsJy(JNl75+PO+c2q$>$CK?Rp5#m5&iwYEe`@`{Nya{Ltr$mkrdH0*$cZ^IS z5ycI^E|SIVq*-}yCoQj8gvhZjuuFJi2HL4Yf}^ULEcbh?`KhR=?(YQwNWz4WbHuz~ zc@~_Y;mH)9yxNN*WFjC*b$oQWT-#3ocryzx%fh zbM3>2MjW29;*z1r7Ux~tf8A2ZiFA3pKIHvAZm5T@#4#Ut_kWr?l;?O!oN0QvuDg$A zw6wHD9fjDuI?>#}-|JCX1M;JuZJcgl2R#TDRa)Pt?($fxCojP}vo3l=oFzIO>O2l9 z;bypK$w$|z4;Se_q6|e<^Kgo=`F1*;67!08(8()^njXKP2BW{weGPXCXE)(!KZr#@ z7EUlhkP2GotyNjWHgKk+v#j0C8-fy1m^atZ(79n;?3hKe=nxlwKZ5C@em<}5xq*CL$kah7}ob2a%87&^G-CtH4CY@*^m z)_`+NFib+p9%`}J!Qs#=9`R(*``*Mn6}p&t*X3u*S$&kG)m z6mi(o;Kp`_p#21cl4uSh^C|O0JQXEL!`Y#$*!+{~D3#7cK%xRm)FTuUPql~$sFXFy1Bc{i4Qt+CA%$d>x_vXqD?n^ z%z{pddd^`hHAlv>YlH|KD1bXb1cwKnSCotH-8wpLn)xw2ex3odjCG<=v}F9e&_GGf z9=v1ru;c{P%WOsnIIJ`((S3D`Cs!Q2z?4~4v!G^mv4^CZ4=gQzn)IXds$d%`%unp- zKdNzaev-s=FRu*t?X3oZ7p{?1_oLG@NQKQ>{$I~LP_;hEe@u=&;s_ytKK;&}KNnr5 zvNekJcB}~s-8)S*#DNwM%F}jO$8lBYfZS!^q8wDX1m~<*)Rtso(K7A81(!7hbV29M zpdX@p(_f(2>E8mZ-O-eS%r556z)^QbGnkO6T33t582H}@PjHa8s%!siNp$Roe&Rak zHAGA!?xd&hCV8aWAPgr_O@q+v$Iead=sz>2V7{gmil9Csm0eGCbM$%n2RO)m^|5<$ zN(=p&-?QlqDeLnqK~^JI+xAi1c>ak+C6q$Hm77NRz7X$kD9r}^b}eaX61mz->%B5Y zA1`d4@33|4UyopduUm*Gfjss}nPwj|)hkk+U?`pND6I4A0)rp2Wv}Gh^}R17Zz;X@ zTcE-pSM>|VAH0y3o<1Za5~0Aod*X+hg-YSY#$=FzYwLDfNewjk*-<&?{q%+yew3bl3Q$)K65Odz$3NxDDFR8i8&&{#4}L z3A>k001V$U>DZDf_q=&~oy(&RV7i!+cS#Aj`UQhpPLpARq@J>gA5!z#x@B7fmRe z{61JppHss5(i3`89L@Cwyy zu>l&bKJ7kr>kvkrKK)DCEJW6!Mq|}1LkgI)*YbqnHH5wU?}?l`@9)qhGrjI3dlP=b zN7^%%&B5Va?(Nty1at7NimOHhg+GMV=S|W zi{J$z=Am84S`?;L)E^NRP|t>@O&&vn08Rw2^Ut~?y?-?!6z$FKJd}+=GNFkWHQ)_~ zyTWW>iVpa&?%xI$0X;nGYo^J3_pSQ=&n|FN)G-$QH+m4%ia@zWbHg~XaZ%6xK}}F( z`(!eZ$G3J)6_|&t38qMrJuTGH`_G40WNoaMlr&|vNM3CH`tpjr=z4^Jysv% zMP3Pb_M1AR1z-Y4K~I&DL9iDM6!oyF>Q{|gUll`XW(;xo(o4}< zgT?t^mg=lqKkYOxwAD%aIk?2+@-b48s-IAAQPEISyQc=!udU)Q2o4Hr1w(19Qsvwq7*0=Fe!ZMA1417giGVxo-m+8;i_Q4d_4&$mVZx)&yVxB^r zoMb5Hw;#`Gf!~0)>HcLso6}0Bh#5^gtq3h+Edyp=7QRySaX2ey0-316JGxJXfhFuy7gzZGUznvQQmXGNxrvxy>*ssY4ga%FF3qW2n!tKG`p0)2hPF+wtTtj#s^h1M;z0YWU z88TH}QcZk5&G@B_Gi~iE0aoD>|fg%udYGi(n zQ&nvc%TX?H^t_LdI1fCo-KsprKbtj*U>|=`5Gsb?Ic}+=&s)rc?+Q z4SI9<`I;EXJWm$o%@R_{FaiECXYGS9d-m)%-Uo_5{Wxy${WwksU;O|i!OA^#3!BAd zM23ei*!8_g)r>+L&_YA5^}F7`YxbE3KVDztn!>pf?p|O1^N_U9VLlP@Yi6oa`)}

e!hB3W~LEmPo0uyXr>&J>p=td z;<7Lyi8BA=V7a!3aofwz)Te8=ySgo&C;ce5@n3^MX}t+jkMqq^MGQZ*?yfuCY`l7b ztV!B{@{Zjwb^0*{4PM3CKxfLYb<>5>btApAG88A~Zv*WqD+|2reZ;`D@98*pPpy^_ zY1uI+uW`U<@#uUPhgpo;#xF+_O=jjgX%R#o6YfgnlH53NYF!VqK!?ck#%28#QuOmi z$k^H5tljg^LZ5s2~$an@w!C2p>0hSGRU` z<0k1Fv@-Y3 z4C0P2kPt5H5Vft+4owE6Nd9aNH2?xO&-vrH!2g)H<^*>CY5&-Yty^qB*IEaQH)|6n z*!FDzKCc_WS=f;dRp4%lY+#%HK+yBox;r};T z_J72mzx`({?>HTzBoh!rrfefpZu#N3MOOfh^y%Ze|faVZxgeOE1n7 z>v=P9KqF@d(#(8iac;u5C^h`MZPL<4KMkbB(%yM=Qd?^VsBwD=w~m9sC~@uZkrj)z z%S{sc&i`+t1^&<8(B#sA-dY&UJrxb5>5pyW^ZQZ<>a35cnk}by;e6~%*%o_Pb3b*~ zVHC(0)H-(_RsGaEeY(W_h!45{sN^HHx}etBIoo78X~U{y462$*jg~N3E|DuOQl$Hb zlXZmZyE_?r(n>?C-*p`6FSLJ+?dwPtAg2J`*|v1Gur+peeO@}h_`Lmv$roIhvW|j{ z!3^prI9r%4(~~*ubcvv@7TC7CcTbwC((JX*<5%P-Vdi#v9eQq|^gh=V9O8#Po_U@X z#y%0~L<~ddFeZU88&2A|k1Qp+K!PUBcsbR&fOiQrB)zl%DK+aAwEmbt7qMdX1l962QO8wnV*J<>90oLR&b7i<%Yz8o72oG@amZsro-Sxdx5 zgmpIu>7M%sEJ;#A0=TLpqsdSCP4A zRP^V6Oet|dycp7b1{BNqNV zyGO@RT9{0`yLIDri>_CxSsy)gjh8bFPnbVTK&x{5MXY4sfXZ`*od&fGAfZV2AYO$! zlXA`ezyp-kN_3St#)Kon$xxa|Fkx5$%@b5e*m`Y+Ew5hNB)|Sa($aJZ?ER2P#OLgq zl?K{;@nW#kYl>UIu66HCipL8okoLAV2oZd21qMIwo2z|kJMX^)=FjFfHaE1AD8M3A z9mFvlp-?iEMnt&iau@juFFxH-5o=b9P!*hSJ21T6&Kl6}{$jEFAEmo*5ECzQHIlkzVCJ9c^zH_)L((n7D}7K}?r-|B$!&M)=km5UADiEN920tdNCZ~g35&)( zRy2{3P&s&Zuc*THmUU;R$!VGu{kcHFwXQYeluXWonopv-x_(`HwLqk{@uO}?dTrT| zTE54}hY3l9rrEwHD8xpUV`)4?AYVK(n029TO{PC6I~nSZ+Fp9#7NNg~>u9<~huR!yq#2+l1>D zIDu9j$=mJ`NLOx0vKt!s?;nm^X@vA|EEImswG{-5_0W;kP)~R}&(7JkG|rL%@!0UZ z^_EJD}#wVIEvaL4Y~PyYi{>g*xkPOTiK}#*VYZ*_Sf=;K?QGi zehIyP#`|sYu5%wgKkyoMy~{XA#cNFG?Uk#$ZRdSwv_{Az%L3bmK$C!e<!Y(snR!&Gn4oc?Imk9XmvqjLz6+GgulhV$#8PJb8bM}xzYN|Tm`Yi<9`5%#2!YBhl6V#CZF(@jo**K%3 zaa%e=e3`Fa&lUJu01*mOK_JYfVFh1ou`8pZ>c_t`i48DDnH;3Ujxmp~cvf`vJboKOp&ahs+x)!udGLKo_mshZ5G1}7s z8%8o>;JI_hi0>yJgu%@BOKH*MC8jXy3qzTAVE)}WnIfL7BZwA zF)$M4KR+VUp)F98WnAg=jZ9a9rBDv|kyzCfpS4d#j1%i<{zsx@RG+e-@W8t)T0@oy z*&+LlJYp_gytrY7o+CD0ZMjg3I*Ni zSzY}`WexfzWVlF9eH-DpujIyx1)&{AU9XZ&L;TFrLHJFlN)&c=ECu=Ep)4jL0nJ?k)5 zBK90|X=oAyuVPp@tG65Gjy!M9cez8fo{EC%258DqnobSE%U+9Oiq$_%0PIkZ2LuGn zK4%M0+uZ|4c(F*Zhmb^pv}8);A>(mkuIye)`o499G;Cxx;YNAUDXQ!ntb?%%F61j)3$qe*W{1GV%oriXN=kc<%6iFVmo!CL{(y|zR zKh*vQjX_oB#~ww2+&9?X$=XnwL07!@3gPtPf;wiw#ei6*>2ale>11R9WP#VjG$_Qrv{;Yjo9Y z@p+y{=OBv&_RY#dLCno}zWZ3~4+!<&kIhp-05koxgrhVs7J?Ym6i%3h4YKQbRi9lgl9*Y6ik#2WXm5tP(& zOCYiyv><#kvATIRh_Zk4q=fFapJ~7h{RB5;f>{IbFuxnV`QIDJuX?c zY^Jt0RXs)ja#^ArhD(S#=yXBk>u71I$Q>GS@DY9^>^^@k`p74aR!|^#Z1%Y<=$S+- zHsSqCkz6zqT}x@Kn0Jm22rC+()2AKn9UWEK0Y70-h{&rVT6u!ZQzlqYZ7PQ^=(szI z`MlY^ZSlPH_$@vnl;EXiKgGKefBQCqw3AN4)FTX`r|hsvHSve8+Gl1~`6d(W^v**p zNQ=@*GPEP)9}4Vk*Ek9A@=^wXxwd#yYqiCtRx%oeAI;2f!l2!n%8Q6aNlF4IKU_l> zHg)FANpgo`r-x7;0M*ym3zIh98OY;tI%M~QMiv(d#F#L}>C58a_Nd4z{=Y zw#|1Tj1Jeda6L?-0s{K7$p9vf#L_D+{}Gs=mpEL5pBhysdqLEhGjLP!ISt+ss*@&8 z^g1^zz7#nml0G* z5U>{tLflKgc&2?QtE>z;(s;WT!x2)MH2Z||25EFG%yVD@GzX;m2~sD>c-D$^uj=1d zT)r(DoQBzjc#^@QX&D3%aju;KQ@=`}$aKMQgKeGZP8?Q(kfprAeBhdJ_45OBw+3{t z6)U=L8_(MuouPihwRClSUX1|7^~FB5cJ~l?{%HqiIcfl$+$#v>gG;C2AQNdLGEff0 z$M1Qdf?M(S&6~l)sc~{f1%O=BKT*ZCx8GF}Lm0Bbnk2O865HuP$yBS0k-~l$$v~T^ zO(R8UTIUt}ryV(4P#bH4#=3#IHwxj>=YB$*jr#EJeM{#u(#r+@M;JR)-?lZg?Nw92 z&u+9$>-NYUG{*L0^wdMwmxk?HOZDzb>H|qh!Mb3%c36IUkh-LIBQj7`nc)YGTfA+q z7E?hZpX3f*TcN9JnzKR{i5MmO`LuDI-7*Ka`qqs0$c|Da&{f1V#uB; zIO>LZAXK49VCxWF^_AL+MFSFZwHPdcUIYzaM>oNplOJd;6V>f$RIi0}wMT4(fgsB`rdr6qi+qZ7o3J#Ly z4Zn*F=${s|RN5HymrlaeIL>nZpoGN}H@mq}1tlZ}(%fBG2IiAamCiJob?YkH+jZ{L zq3UF5GsV^Dr17{GS}ezDG({LMDc#N;k=p)4dFFO8|Fv^hmne_TM}$E{_LCq!yHSrrU9e#s()XYM2!K*#j zw^MrJ@52QGf#?U!mRwO=vlO!t81C%-8+(H8;Z~nagwNZy$+RaYCidTc*R^0l+hVw1 zT%;$R7!^Bl;zXlZX}zKCj7L9zo2h2$XNJMJa=>vohh8#Q^Ukbc^idtrlDR{N=0mLxFb(Qca|+{i+uL8%YuTIVN{(hV2(WlT$aW5Ssib zI_S)f=P_a6?^l+~LS~`IW8uDh#yHB%8qPC{6UGaoVtf8C0wP{cBFGZ0h_7-vBMDGx z?@{_L8(T)C7d`Fau9&7Q| zi_BlOHn`-y?rk0GLBFm#;Ph)4L4~l*NQ^{Rj=pX#`#6dCGKw1U00)>ux(C;_gdzvTW?C4 z|Ee-Fq+X=XXZfDr?%Pa3Gd4Kb@bqFdudZ8*e=g^HVoz{(7W2U=b_4#E4Z&2EB@Okr zZ^@DFXNY@lmOfdZh@jlH=Qr6)WMS9UvN60%68RD~h**Me`>6(x{u=L2qov{;tnbu) zxc{~Wn1IVKKI^EOQH#)}Jvf5cM~GXcSv)^sz?ah(u1(6kswrkf>_&;`UX=sjiCJ0x z&D$Z{SO~}*QPP&UGPH%?$|K*F?#n?;qh2L8(seq49=tOZ?prr+rUZhTa-7D?$U>sk zED-ly*&U&Qv%}XdWSEr0 z<$k9Zx2HKHKxT#+1DS9Zr+p5bE^*FD%G*VEtcQWmR_j69?LJ}Sk)&n+>}i@HV+L%+ z(d_IT`1yr}Iu`0K65V+Sii%TF(}W2b2bEtoOo^=7&Gas#)X;%mUtLz9uBKo{x@{FE zrL<3O1C~;c6nl)uqVwB+g>yV+Di+i)=6l|lQHlz;JS_bbyN>?&QELzo7CTV}eZfOP zY|L45z<6Ac3cYoZ3NH_Ic)Oq6ty%w_a8Oou4QI1xu$2bq1BD@r<@xit`1o8I_hnPn zxuJER6#W7*NZ5~^B7Et7sknLKl|DM^Gu=TNN0~UFAY>Ay0afy*`=zX_t#YZUg ziN(dzhIGcl%o+30dzyb8rmJ|9zA`3SMeMt8FPhl&-K|$D+!?JLs&t3$gRT!{v+(Df zyZ!eBmXS__Y;LDYi5)K>w#K1fu6`Q7C&62fu#q!SQRv=2d>s-*emQwvK;g|bGqYO4 zsn#rG25j*eWEiOCRaJp01W*!}I_Chavi@@0AN|@lNjYN?Fi zfyi03ks-&}#0lyk&FTz7j$FI?>AwgO#ydn@Te0 z2jqD&eKXnmz++~PZqYy$NALT`f(=l(-g+L!F`|!t*>BX}J)=ond$Qb6jb;7n$A9sI zEFoQK-`&ml$Df3crY}BfjqHE$EvI<`QL%5x79nqN^6{rP9PW@kbb2x6s75{|6|*+~ z`U=Q%?X}-qeU~$IO*BG5DxE#OR3>TP=~}4aDgCKW`@BTnZ_hj-&m^Ky>z#{})8+OD zbfx)8M5QbmBdM3NOv>k#j7a}M(`TQXB6lcfahmI}XSY`uI8Gx%f}7q&kRwFUSZ;Rq$2U^V_75EW4!~I6f*X^b*e0ag8&|BFj>&l8>uVth9 z3z=}sel;4sxVT|qhFF~|Pu?pKDG;Vhvbm^RvQre)2LH2!n^!nT&TDS4j67&e!dCmzxn>cx0W`cc4=eWUA@Gfj1O$ecfI?ct z@WCCOv?e?C6DrQPht-qE^l~n`;>diaF1<_DUQ#X$`W24hXhI=+wTDhn6cUKPd~r;+ zrfJ%i&hbb}_LWwLcd6^PD{fD>!6{?y!#Ey>?4XVrKIpA@;FaWrFhWJWLsEaGC?yug z=2Xt*jWi-AGkfc{(~#*YN!b)2oxQx`sps3yPVb##rG&jFdB?Y3Y@DlS7#LJlR1Ev& z5fO`ZPf@3=QZE^m-D$$9swRscfHM*Q)74_UZa4_8Uw_)*MTZ45bKo-JE`)&VDN!E{ zt{oGDVUI4TFJ^$G{H>@4vneaW~z?4Y5m#>yPT*SwnnsN2&OMXr!a zDWNua z8cQc0mpO>Z4g3?1r1d-CP>0T*alHxZjLqBk$&=)$Coa=Y6ZyzI9Vs<~K-X}bT1&4Y ztB>h<9n+#G<>DRU!2eCq!B{^A=6?GgG{_zX>ZzlQ2RH#kVYpZ zef=@O=Gx3=Z#}A?gd3(j7jV(={yj4}BS3>h`aj%ZRYjVF83}!HujoMmqxWRmrk$sr z5u&D;@aJ|nIr%shtLEA2hCOfh+_|}Q!zinvWSQm8Rxp1^qq^(VoUvLBnn;mKfFMyx zu^f^A*}VTlfUl%9AMIMEtLRU5{gR0z@{_obpWl}K{L@rbrDaIZdY3a~ZSNuRa;zP& zN-xcH0Af7{>vp6_pr7wNtZ2e{rDh0vz7mDJO_Jg>i6B8seCuUlS(#YAdCpq}tHJVJ zzb)y@y+EU!$?u^no4Bun6oVHCNj~E?D@5$DyGXl&F;J=Hr1>&U=J&dNDZ3#-Ad2u5tIQ>3CWD!RJ8=EB__UW@y1r- ze5RM|9$`|J)x#xmmbuy~pS^)x3Tj#5>{BGI9aY_EE%%;TmGXxktF%Z>9xh^NZ_Oyd z_N_2sWbPY~O?5=S>snS8nRyBlA=yNphMv{VPhx7B1KIuC2_AKv_Wg$|vP)cX!U*P{ z1nA~8c%k;{tug4`K*-|6Jo;xo{vLDYyf$qbe0#3$VF-@~=n><2JI4~TTEgjWyvj|AW$fRz5kN@9bf0h3fe{`4(a z5h+J<3_aI2s%uTVM8ulkx4PxX8FRImXLoMt{JkO~#~4~fEdUKN(Vu$7>ToAd|ENi2 z#(Z47N(CQ25H<(g#MZ{`-7&lFsQ68(2aE1@-%TcMUr3~YolsH=F|i}jO*D~uh7H}r z=)6nm%a?zaH}FQIX|qr-BCE|JrDf2!k3Go3cMHp9^5R(YKl7asCCD6P)*m`K_N1}< zkT`LGbEM(!Jh>=jfs10EA1{$&9(&*X(HrV)eveRn)CB^sWtTuRS3m!2np6YvC0Q!c zp-vNl*kivdXgBWPz;_{xP2rbcUIWLB7&^9_L9eXATV6w#?Ukv8OiN4CrOYGdvUE>L z{|DP%aadG2%rkjj^JEBXr?dH8b@lnj>k;5m5t0*o)K9hgV&#h$w8I?i3l=WCc_O1s zM3}y?>C1jGkJy0^AdwJu?l(e>x+cW%HVqh4$07bK>c6Wjy|=4Lsi2idUwCcg$%L>4 z7ts@F)CutZ6IFR_wP&OEJWu|RGH7f1BuROLiYcSr3MV?=G^94hoerN zB8$#Fb-Kh#4LMBXHSTIywctpoBmA~tELD}E^MtUMtG{vcCY=p;>K_NP?QLvst*hGE zw3eh4>^ZUhMydC=m*Bld=6JJ&(N%em^E|$I++cv2g}Re`h^m?H3xYtY(GJi^{{WT= zA3YNQ9c;L59U0Pv>C6!@S~Mq$=&yh1niwmFtnun?D~)Z6kt9IrdG~V~Yrk_RHAP$7 z&J#iyP3%-_j_7oXnj81Ltso9RGNN%8*p;k!PO+bU6{0pI0X!1d;tM>cqKgY@wm_lo zt!!%jE{qKdyFiMf)IXl$K?w;7M+?X=Ht@Y0yW><%8jYc;Dr=kvx8?Z15Uz7i)8^2; zn#}ky3l^3eziQ6B+!iv~5loNK3IWtuIO+{BuYY!cqoX6#;PIzly5?U~mpjC&UlZr$ z=LgsGa3M(4^MV4qY=6igO~!t^xjh(;1?WhgL08>)a$N4;|6}_<95a5R#+dC&jG6o_ z89GqTT=(y(2aTl*=k0FH>T$oR7RJ1FH=X&{d{pgXD6=iKp$7+lXk{B(bQyhSQ^)Ga zA!S3267h~Ny4)VO>{?$$Bi$aN@3}#5s=)mYJhT6|ckGct8>fMoUi0<4XJ*9hQP^$J z*Z`ZLwNgE%AUDJJfh300nH>xP-g>O_Utq|qY&c51wt%&Vy(I@=eXS_vC0 zeb8u5LayE*nKEi>%h6BvsTpZ1kfxlZu{dR4djFK4CSr5v7i!5WJl1bq_GPd^XU-vc z1dz+)+VmES^lOcgOn=eTYTje;Cg#f%l z3SNxOJ)FDMt21|2wXxqB`?{4AA4+|HvFw|A=iXk0g&(EamQ?S*s2`F8`1qCEq<1JWLKY>bR<&>wCTv)NB% z>a4;JbB(#5_X65=t!(+@ciFBNr#EHH-&Z&_aMw;@Wt>b9{l9$T$JfiIl{+ORHtxj5U~L@JL?Rn2K^_2K9Z-9k=ZNJ1pq%$o~hv z@BiIjHe9)_Upuud8fQ?YgT`Qdk@-V{JYEJH1?($t(207Ew4b!~rjAyQ_fMf$UTrM@32h-)dXYT-;2vzHoG*7tNu^|e9Lt65G zJYJ9cW?aQZrJ7dG?rP=GA&>2AMon^_v-9&ytm^4DS!m4FPt5+>Z*{D>@5kTlApdDb z^kXz*3bIvdU8{DXFoze$AdNnUe1P|V2#0-_o(~^*(0Ci13?UpCGvI(CE(tX$q$rZ+ zN~GxS?v2;-Iu(V1aaO4A5ae9U8764`=+wm?4vnbXC?Dc*3h>fgO)YHp=XcT^DG;y% zR)D@J46=v4pbB_Kf6hX`z&{hIS+Y=qC;ufBk+ROvg2anwZvu8_dMMty`5#g}YAL2p{U9t8vU! zva*aJl)9cl4yU+jz*4XOkoNQ8d(hg`Q6k^Cyn=xrp3${i-=>j_PI{Q0Qh;I7R{;mD zeMtdE-H4e8Q3-1d#{|XLoMWaKju_8({ApJ6eogrKyR>DgsU4)55K~){kxH5BAWs^4 zF(NyhD^brI7X@Os2RDu2FQ}URjnHN>xx)OFtHp|ts{21zVyvZ?y1XJtg;RhWMY9>v zwSe@u-}hHjCyPs~UyusB=$|pKD*YxBz-mUjCFBdBhT?w#(vCl!7YohoApeA{fK{S` zvk(hCVSek1^p+4fNvOw)>(xtV<|n~zq6PdXW5spwWw}ER)35F3BLSo%^BWXOvV_9{ zCkPsp2vuxb0u_0D+1wJTO4ggC03$QvuqG{Y&aLQWI71z1tU2Fwq5w=Y{<%uI_}nyt zuxG|4)}*=VbQ^ED84Vt9p&`e$(xr%SH=xghbKom!rZwbl`njO|JZ2j9c--0TXRbIr zqr>OuIMTfL!)_)CnLbI@Ku5|oEV}Hm6xKOc3k3T3Q(`0z9vs3js8(N{sF5NpO^aRV zuUyU$?+52#p`Jq^p@tsW*VSRy{#yj>nh;3l8^4D0@LJ&>4beTd1$fIBlU$XKLX2-WX&Yt?Jcc=n)D$cpS!P z3>${Bibz(qbvrXEz@K6Y`Y2DI%#^VQEN?Otfmi}0RzcU`6Dl01QBhQ}yD$kxD5tkD zi!WO5lUK$t(O?y^9Q$UvMSd+__aj+jH1cu(kkgG00uLj#oR|pgY$7elj@BoSgc;!* z7@<_&z?%sQ54OE++cr;6&r(ScCENHWsh20f@F{tM~wwf0_rtqPB#vC=)Sdtuy}g5fRSF;-|uh= ztP|aUy_W8@X^d*x!uT98CLxKZi7=zHzU0;jgy%2H%F4_%5PQhzH7O}1U%@#Lw$C?q zca>(NCNseP*3#k`;!sH{!J)@V_wGSiGEIx*a^i961DI%mUZKeO3u*nXDNh}@LHZ-F z*|-taA8~l3^rJ2_ltx7b#YXa~e=^U5t+F~`~_-s$5k>h%W=vbZ%#<`#@`5Pe0 zYARs3O=L3qz7*W-5%PvI=n@Eg2s?S{<7t~BioYDHL+SDfeemO2!#lj>|`K8gg zxt=JhU$b^C=pn8MmaFvd;kN3S)fg3*US@fooK)+?mT>m`6L|0y?Zf(3RF4F zUFOGAuhm7H`vh;*a2(P~XnC276qNj9)lhZiOF$Yu{s2cE2HWhY0kLm)&VH8mOz)5? z7mtT}>p40N?EYcZGz#lLFhU=aT!5NEN9$mr_*#?^rbcJU&jlDqeL2L8*lu5*UVb!|2Ckkw&XM3{?3 z&=i(H@od^W&c`31K4UzGQA%1R86_kM)OP@S_Kd64+B}YPGe(Zzy!9V z3CD#*a<)-qY;aS>u4?k~@>F1)gixpW^rQwRB-Dd{krpDRaJJ$zo3^7GnyK*l)WH2( zzI9h)gzHg0fp{-=k%1ep^GGh40RDs!G7Qem$wd>6{};C&S^tGyGh)^bd*<_X=FRH4(;fgf>9RVn!uk;?k1^GX1F0qgk|emY={dDW3Mn zk^__tylFR+$u|lLG=;EE*?tJ)66)ac;&X2MDZlt}Gsp%hXN3JR5wcB{E-nhZ>^?IL;HT((m3#!R#I z^7_W$rvKlVsFpl^`gG{NHiUU4lwzEg#!FYgQia>XpC^?u9)>#C)40t0 z){g!7i@>lYbyFmNn`2T(={IRs-ixg>{cPcFtW2uf9Xo`}hd0WDjpL->K)zGM}JNp5%`xL1aMmS;|~Oq zaa$;tB9Xr@CfW3napjrvF1snOV!v^}<+J|GQL>;AGNSTq|{QIyY=XT_c zEs&UapD7Axtpaz5TgPo7{?-gLh>T%t+w+?n|pGMNe(046E=-b$|XN^cPRlA?O55VnSv;{E#@ z>6n9$v!P>Bd?|NvX{YVVJ5MkXqLW`*kGWyOF-GGHtRHmWasOnh;*SZ?i_qrZ zxi(A5N8O>TVEM$@%JTo=0$h~reisu&1xll67LGQb9>JmtxMSfFgF(fcQm&~FmPl&Z zZY*mkl7FkURO1vN&u0^JDat9&VZ z>X#%37(G~~t#rDA^jOu?TRKiC7Lcqd7ie_&>odpQ^XE$l6MenNrL9eC-?i_>FWm@} z5=>X`Sl_{04W&V8`N$)N9+?_+)C)V6+HNgdyGOd}l6^SY7j3OCt9L3EI9rQz zd(;boLP$C6PGyXxKBD)8Wu6tqPtV`fB5TL~-0p8PSYc~RM9slqDWH^*Q`SKa0EJ-D z-ZgEw-)Gm4q01&5H*m`7eqsB?ZsE4&asVHR+oaB&+|hF1Njaw{u4=oolG67#BXbs( zYRU}j`Q^B3Mq1;&Xdel^`u}HQ{UP>P%O{l5Vw7c$N+nBG;y|D0 zzX8W;3lwpZZ5iV)hZ0A6@LG*>w(FJwaz+3+rwXyZk1><>&-?4M+z~^bP+E?%9TZqH zTO!QmhBKwntC)_bf~THx5vKjqjxF^3TA0wd%rVfn+5iH?GxV3PgFq#2D^2LayZG^_i~Qm~tkb ziQcY5DX9>2H;CRtgPWdAUTln!6*GOuzs%I`b?7XaQ8~QkZsG{$Nb2gJfi>0uNJjto z5>2VVVXr0^<>t9EYiOU@oBV9|M*VnYbil0ilHi#B6vAA#@-ws!X_it#_v_Yuuhcn; zx8b4o)LVaTHJ-KyqO6Mt$@L5uhzsY>_uoDj;H2RaQOE$A_(yh*Ic)>fSu)3+)2&;7Tj91JPz+6$KqBVKS$G&PxHTnBHNtQENw2ak@gO-++l z=(P6pT5bEQ?dtad7}F=G!!y`84RPOZw@%TK!R41U*#BgJXy8#5>GY$J{nhhCU~HH} zb-0BTfM(eYJW6%Mvmyja4LY=eEJIjUEJCOmu_$w6?#QkDcIZ!J1UIcNi=vmQkI>Io zP030~d-*$VTEZ##NGHtBc0zT6$?&{PnU1$n77yLl|2L?D74_0xB*VJ^0HCiL5lF8C zMI9rVzjS@YwfC{Oq8enw*z_>Kr(T0E0$6*0!!;)@nNeoARN<lW%IwqS;K zR#p(&$mGFX4PvG+TzpkgF8~V1V=T=d@$I|Lo42$VoHg7HfDwFX$wSJ%FgyVm0Aqz^ujF=_sp0!R^68WQ+R)cFM5-KcDh6V1-IQw{$?UmJg z_M{odRlIt2P-|bSyvjYeTv+;waj;*HKL9HNP_g%%siU)1Bu~VPOX5!LNXm=W?xIJJ z-ml4EyT(XgJ(igxeVENVX^iD244BDQ)i9YIEO&Y)i97PvG0Uv4vUe^fIFJ}qbC`wNl72@ z;YpS>w^yt)`u#*MeQR0LD!4EaAj>#@jQr-bwXNI`Njde1u+T_pYt}PFC&Nx^(<{G!RvG!UhfSm7T@J15P-qSJVttE2LF%*!mZ&#R_R)s>=UsKW zzy53fdWyU1Zy)vu;mVy7uLn;{=kb@xqwd~2JE|Vj&6P=%i##s3@}+tPqA8 zvJ`BBcP4S&we&rkQbui1IeeZ99-k(mp(JQ<_gTUs>)OA*{Z0SN-I3YF&!5MesoA(H zJW~*y_2@TOdChw_)Kz!h*BskLUB?QxS$Zk1+zjYJ&txknh;EUy(^QCQDSxf|?52&cH!UTQd%KSJr}vijmj(w80o-R;lv({v@E8rZBMFn1 zflhSDkh?E4+WNxPNgW>z@!*Kz`Y#VxojccWk`x6pY_(uKTisXX`Fm9^dOzmtm7h3k zAZM_Hx3!5bRSX_V>Q+E(P0Q=96JC{e?{$l4rwOP^N?MGoZP5L-hX#pW`gRYy)IoR^Q3;=!Su|WKNQC)?d+?pv+JM zF*z#GRVoVc{`x2U!!GwNo4vryg-k6{?-dIydYg7B@@(*Zv4pq%7X^vvaUs~0r{3Xw zTXQxq?ka#IR$6(Yzt$k!;@yh>60^~P&52Zie5xh~-+2F|KpdKx8WhzN)ofb&9|?HH zB=|-8`q4`h$TecSFdTPa7v`cIXEwiIE6D1vQ^aBapNa%TQ%HP8$VBq;2rpEOMBF(i z4S#iO_ZDYobJ>GU^EQYbr~qgD@@z2&+&u7goEZAGUbDY%{dBk*q;~CDqZM5o65Hd- zI%1>~T*q*&^NvsVn5GfzL0paphcHM7m>!-d!dv*{qL$I>@W07E@ALdLBFNO7VUoG^ z9cy#;dAYe!<=?Gkl4z|Vwfl}plpf#H5ElZq9{xEj%9f>!4{BEI&06<)F7+>h!G=F# z${cO`%dUR?zC7o9>bw}{b?^?KSEWNihP{D<#LJh7RV|&Jnm2MTe;9FR&6ML0NUKL# z61Tr6`ahP*+P`9|MVG9k7MtsQ_(O1)@LGDr&)+}EQrSQji`7@w`)*Zz6J4!JYtN}} zKQ0W@FM7Ozj{)FJZSByQ+c+y!732YA|e-7%;b#ikN9F9?A;kACOR} z_qn#kG&u`-zA%)=_93h6R`H=UYYlyk+a8Jmu`FjJb0mIe)aUi>-6%})bhZzFW5wC_ zP=?it1@Fm4WA7lQlsfxp*-)!p?CA|oUoU-(q$`UnY10{6DjWSkZ|HzsV5>7QAbyJSeax1gf*`|C_vn|3lKo|78+` zetA~+wyi(fGP3&sSUvmSlP@Iq6^XxcOwkDOjw7J>{~LlxU7zj+1L{($#?msLUU7Ya zt|ZsTd)}vuxv3RfQQ$1gDrg>&Qrg-pH-Db=n90h!yEU3k>x13prFHcwxQ884>{7$5 zX-jmUG{Oil<@usH=XbC0-^L#dKDAiVx$#25{^DJ`Axe^_^5X*DDLw0ZCDS?UvRELS z|FeYZSUB#OqS*5(1dBwn=41xw_r|b0&Wm0kI0#)UR*V@b?P@OAf631KKg)gjzj_(} zAH664_h0PZg5Hm+CSuNo7;Iqb-@APR@gm`XWC(!F)UZMp{F9HMr?mbsU1|Efh98?- zpZL_JKk+!wJ+n$zcS}+BhE1ZueygSbQjzucLq=XU6C&h8k~r%#_6w@*hUu2gLO(A4zNJu5sORBCvjk!6^&o;uZ< z7*LnK%ZS3X6r(o90=RVu2n=MucaE3hulw{VO4zLEL>G$zDIHau!^@oS%{e$f(BSGg z$pqI?RF2fJuF=}glu!RhhklT8op7)z7BCFJJbV9O@vCS@IsddBgP4ueE=+^_F25_pFdQbSZs zKH|!ZS%M0~Zfx@@VhD^7$iL(R-MooQcl3jQ`UUoOi6==E_bn(eYbXGteTGs0&7BY z645JT+H50&JIm=H7ocpcB&S6# zaL%d5^c7Z0eFu{f)8km~(IXd5GYKWYaVMJkg5wYVxqit(f#BuLZ*!`@1n1hwD=RaN z92)YfwL9eKQHET|K>chv(TnOACzPI=vK|_gqydwQjngWK;#sJQQ40f7E35F+rzv7% z)nXr!l3n;t+x7hmI*u_zqzq)mp*jMaVb?AUNnD5BCVIf?GWC>=&AH6X$_>}fjnvZ9 z%T7;!G5cwYE1nVB*nz^i$lQZy`zPHh8xh{2h&6C$qKf$j-$E-mBq#`Y);((<@^Zwg zWV>HnA?V=s^fUru^EWU=VKRAb_0f7aBXitg3Y5KXr>8?=EqQ15(~-&+22o8dEx2&V z`w_EDS#`tF%)-L7w0I4~4(1;N1`OCbb96AnP>LouPfx|ma-P;he-baNqPG=Nuh_SX z>PH6+C~$_&jwbFGr-=br{lJ`$Z`?W)lOod-*p)LCL`+rdbX7f2Ni*Ux(I;7Z+cC%| zX^6y>@PFI)B43eG_H=?kVRiXmF-xz&86bkf8*hIhKFRCtKR+jwPAuH1sgULUBU{yE zQ_5_Tw&Kjg<53_h&(3LZ(n1;q4p`l~VDjsQq4%6$ZJbuC#oPsc#$yanG!;trV)#Nt z7PbJMUp_v11pk|;9Q0pE4AvpP6MXbG$P+j?X2=2{AK?RG54vK~A&v5~GH4DLQ#Ka~ zGzdA9(D3+-st;r#ali>pU2kvib@j@vH*vm)JOMKR!py*Q#R4~-C={jQQL;!T&sj$y z2pOD@3K~z#FkB$)7s?qyR$;J~^}VOL!P#ZDe_jc1)%&Ss&ucd<2$B*XI%D-!6tN?l zl;lx<22%(cH&b(Y`_HlU-#?zBMh{iZT1UF5_vbui9EG0PicgO*7!+gDkiM&a>CA|4 zfBv{GY-2UeH5F3U7FZ28A<0fm)lGkLwEhc`#^iojSkRL*EmAd+_rrHkkYT8q#X7kO zdDRKaa6H#ko+#~uMNe}xX&D6}%5?Qi=TOc(-t>#m9|0DDsSS20 zj<6v|1TtzJjfty{5{WVqkC!f4av@P+i2pA~VB&zfz{Y0MS>VhQLi-Lh#U9RN9K6R< z4>Lh%0QnnclbNetC|EKzb!i*O%H*9hH-GlS+uQsQV2u4f1zj-dfI^u(I7PG_Kt|WE z&oif&OW?4gyQ0^o#D{RRSBD7_@$)>Gp{n0Lp2}GvcRnZQ2cOQbvSoykvx9@fo;|1P z(!YP-7nqM=4OQh1>ugUDBE~t;OMaJ^2~|C_=sb;vNM;Tm(u6^|XR#WL(msTI_87oT z+MgV>S=ni-3?M4LX@eyD;rEPMMm@|u>x$Z-MbX%=5I>btfq3_y1Sm;4)%qe)F&)aq z6@=^R>FWM2?Ni=-(r%RLX69s&I$)9e&x={8`n0Ufs8rL>mS;03C)!qCCJ$1hjuI=d z!P`I2HdGaLpR&xl$D4Ms(D{I%X*3m!&hOO;e(E~uWH>>2QBlVI$-a)`_A|8Ryx6h@ z>ubJjN`*D%IxANa^8Tk~7)~x6F{4H~3nZH1hs;+rq1eK4AR&4mBWI*)d!u%wRq{Hp z=t+tr#R-1Ou2(_L5jatw* z_pc=rpoa|fIgORr*|Uay)5;vj=(<1Ko)!&=3uBI_6V4|pUAu}U6Yxm<0T;|xKCGyo z{b~JrU^_ZXawtxU&!6n{`f)YT(Xbf9nK!@1l`s=~&0Du`lLAHI^Wj6hgFEWY4%2UW ztHX?@r)mWg6u}}*sMzW#P}jxA>abTFQ=e)Ag~ELmyyKm;iWRRir*CSSS2Cz9^t$?D z*@M7xdvzoY1hsuu)^8y>pm(=+H$w_F^rjflDMD!|lL!BU+4IQ<8$~ji9^D!R&Kzvx zO|Bzpg5fWuEgecUpMK`3T7Zwr5>r!)NyGi~4k#t4EwU&>&511`emkpHC9ltkDtuS$ zRasqK@AkT`9fBa_MP$0OKM;Qgs_+5Sqxhv3wk)3Qm|9hAPea^C0q@nh7s$4QNde;uoFxr3BTS{! zg+t==a-ETVS6PtCIQWo{igd%4ZdSd31^}gL1*1#=CD8fM$eVnU4pnV*xy5c1*R@P- zMa(EmKU;_;8mB0X!{ebFreaWbJ=dm23n8Cn7*NE+bmC6q$u~b-hnDi+;}Tcm9k52i z*i*?l!+Kwiauy6MhPMm`OF|Gat+_lxV7rnFd`_t!61S6@@HC-Ru-*|>2?iwhoo|Ct zRiX%CZjHIf>(>HFyWT21p1KO7$M}U}KHpcx?)9>gH}((x##R|!1f7g66+Cv*;9SuX z^ApQ2cEX{PO9fj|P} zYuvkc58I#Mx>z+=4N2h=QW%ZwJ2FpXT1rx|uF^6!xp^}#YMe?+NRZ0A3Lp+@v?_YP zt-O)BQi#;$L!jDhC@q%;<~hG|Ox+;ww~w6_M3%bi%o2F8%qaHkNlLth_j3|t#RyJs zMa91zUcTGu5bk%=q}Mlhu|?U^QyMaPV#6;!7G!PxSm`T6cR$xN#tGewK~5yxZej(H%Dt50XO!*fcnFSm2%k3#N50T z;sGdrq5lE(LXdfzm3gAAd-jaXyXvAH)HH&D-|N?}N!0`0>>08)OvwJG@(IkpTkAye zB7Jvz8os>H!|XNnwzR=~Cc{TQUM)bqP-?-t%PY;7qs)bA!S=atpVB-1gi^$pRonNd z$&nos@pzid^x3l+cO6+b-sfX`ZsG1XGYmW(dY_uh8Oa_@OdMM6dLp6Zhw<8xD=fCn zj0ja7qr==j>kNBl0afI~k!PwgTd*Yd<&5-az(lttd}KZ*6o)QTWz<))#jSp!y$xQYS5j6&o$FX7lh^{ZE$>ZDsH zCX#aoQ3;V!W@g~^@rfFQ7neB#4(Qtc``KCSiyj_t5=yeh4Je>(Q%SYhvu9O8DdQWC zTg2}v6*`)lgs<~>7j@9xY)>P0j_yV#&ukWseZ^H(kwx`$-bu>)T~ntHv#_AHx}Z7Y z*PU{Q^_y$(anrmk46l7A{iLLIM9BFY(Uol`*Z&YhL;P*OIrYC9s|H*QQNi^mPJ}b3 zPdjB+EeP=|s@3=(E`Sm&F+3FOzigb=C?MHSB~tRm*C@JZmz>iKQ)ndAY5DS^O+qQ85nh36oYb$GK^%@>l_La9lp3-Z_npYTXLV^;z5Y1%Z{6%&5@4O$_p z`hh;ojY0i0oN9frDHMXwYIk{_6zTM9e=9Su8g4mRaG|-*gts<bIWE7~zH*&rEpb(BzafgV?~0Dn z5UHWcU%Ixj6c;S}!u7C6D=RsOQSg_9m<1&l)z7coQ=&-};3Lt|68lSKjudy;3yPio zc!0duhsUGG7u1~A_M5OqVw4bn3#eR076bU@7cX7DY>{MFn>Vr5twYHiTX+voPjabK9*CBxkC+2D9&unoX{ah?BbL_ysU2Uq(R$Gf z#{VY00WVQdf><>{M;&CGw{>PnaB!?!1H{n}%P3DzE_%_Lc4^Lw)RDq9tl8y}JtIwb)F z_(jr|MnrBhAtU@>)j5M3c~58<>^yeldRq&Fq4G8CBWBV3t0YS<;178im)< zm&(h;>h6HaWoQmPjm!REV`yDEt;}tjlFKUtZBMhY*oZ#Q{q|w}gX@BJFc`ZyD6u$# zk};uJ0HU&5-B`#uqp;)e(Yzd6h#x1_xHwFdygGRqNSY*`EWaqY&^pe5P8Gah&K|hl zh1W>}`SaT8ct$SLK5eu07K@)oo%de`rSp_MZfe(AX1;>!K#4b}UGIRZ3GB!VM^k|z zQ9DzL^Jp3WnV9(R&e`U2b1Fa|*sBUwcXE^l;HKaJ%M7+@L#E(aShQ-`q(Y9# zY>>&oDcQg?uu3Ybs|kVD3Z5=66CO{Z*$pMm7jUdt`|lqMr^GFn46`#YOkvbWEL++q zv2?bYt624cVi8FG+zUsNYA!z~Yo|ABnyj?T_j3;6WsXLX2~lPHiFI*% zEl}?Svj7HWA8HSM#dII%82fH)*eL|2%B)$xK3Pz*WLj&4*#R(3ng>Q z(}(^YjH!uMu<#KPw)^{BidB;|I2s(xM2xEn&5zDb-*&*rmW!-YPG9U|b6qohRQ_H* z(Oj-%i>>AJtuy%oM&_vJ1nV)jlocQ6OrM_TtXR}BBCx;F8b8~OmW&>RMh)Gt`RC`F za7=zXRS3HF^bj}wy+CCQN+5XAt0#*8=zD;#MmDu(KRZNmhnAwI@rPgFez zxp-ADg~PZShpFxl-zsyykt=FAxCV_lXEY=2GfNsy#-8-#Y0UYYKSjG&B+=*2sv$kz z|Dc0d1n4#A^ZGt_PAyC>pDyFO@~QrZWMK??P}n{OqO@9JftEHPRwvKTRK1Ym@P_ji z5<_0TbpI?LRbQe_RYIz$_P`e#r_ELzB0eWZ(V0WFp%kSqX5;EO0ku&b`!@UVeKo_7 z?+iNvF~MGwF;4h?&sSz{Q48kmE^FTX%(lB`*+3t0fhL&Cho^f5JIeOyvsdXwi)&oTUWmIi6JN=Pok@s* z0KSCuPX1IH{k=0|aQgZ|!F8<4ST)A)mLdJo3h%x@T(WQ18?RTk4f9Pht40{L5A^aj z@BCYB(6)K)>A$?udfli3lRf*NuWC*9V@*O;q?axHDw78^sU_LWqnbv7r`S5sfb8!z*wf)C~`vaq7vO@HtckJ5rhpeP& zX;jx?;R?3x?zEP~nehv82z6Q-CQs4BL28M1Y9pHHO}iGzbo^QvYk&SZ-=a*v!Yua} z>~cZApdwI}^TS@eS$CZ9Ql4XUd-n&{ktr*1_NL1GA!mo27NP}vu=xq8U|Ini(b1zv zZ@1kYETDxDS#D-zd`VEAzV~mq^YpKLb{`d-A0@4Y*VFl>Xw^>|CZR}{Y;5+|rTcp? zGG76=C$DyH;|QG{02xAvBtuU2na#T5dfwAiw|CSlsTW&s@@j(LYVE;$qDP>}%=@R0 zlK<>Wn~2>2qK#99vQ7fD*LF!~Pv2F^GQsGnRE`3G!yik|Rz^1oqL7u40dtxG5+2M`no6k~hZ&?~_x{p;E# zeGLNh=b5wTaGU~!EWfxg)TOP-IBx2&BV4TDx)nk-kA)3_n?{$1*L7yko*3{ZH-?&; zxj{?bR?XCEHO}`dm24hQMwF!mFA{ zg#e2H{m{V+O&1u=!*^73ZzL?n3CVQ-V)2-2g?)g9gMQr@@x;BQ4sSb z@}E<&wRU6e6J|Hu2>TKYX?O4LdJrflTz01h-SvG}h1sE-9kEzVj=>1{L&tdN+(Z27 z?j3K=H6I%Ms?h}v3M99X0?v)43G-`mK_UP#7>gwv&jdKW$T&NnXtpB>K3=)x? z5vUipfT91l@4OU6WIvuh%vO{_emnlaUb9Jv0f$Nl-Gixt7);Xt@Es!ATotFUpWAs0 z*G$rKQ?20U+u7yYh%j+cKEF$`z#%-ippMW^C@EqE5=dlfDsn1tCV(0BLhzEUH_7!y z)PxvZB=4li`4aIMIF{xR-h228g1fu0$?||P0^OnTdSW7lAeOnfhhz)R&@)W&i=C4F znR*s}?^}kX?$vWT^69zXlNDd4&V%t|hG0^=pdE@04kp z3OvibQG#YxrA_o~fT>!e>%M@x&q%pA$>jp`<%|se00}K=aitFf3#=D} z>hQ_y8yg$3rB6y49BE)-;iH1Af~iAe7blxu`W8UwS&JJ=6l5C|AW z6DSN51UVsalAo=r36gX{CWW8Q+O?&cQT0$0o#1Sf;Bw9o;qWkp3uZ5!auOA%D< zqy>G0u2A5MQ7Gx$zB8ovn_|5XT4OlQp!7zpn$YAkj$|r<>*tSm>%T**e|Q}pAKKSF zIlWb&xiWdko(1rT`~oWY?-+sQU=5vNeczb}q>lEIiZn4XcBK8ZGyK(^J9pOX$lA^1 z=G!+fnY^=0>$x~CjZc4-$=hhT(qed|NNx}f0z4eSOosHW$R_~)nN`_vj@e<{1pHfQ zB-bUMQXQyEW4Bl5HIe416$r+Oi-1{|S1a~`~zXTbVykF~| zzrkd}aZOV|02_JNk8dA52_YIJ;xU~+<3z-h2u*kHTy9ZjIL?{Cd2x0-`_GT>tv{4> zVwNUU;sj4FJ0OaHb}6!&X{o9DqxbbVKukQdgp3jGaBqm5U!fRj3RZb!IWpu00Kagd zPeD7_{z6sK)MM39&!NtyvXqq;((n$bY}ZskJS9Y&V*jUciWVMBxQw6j@{vRQ-J+C7 zjdGY7p*BVj)EPw9a9pnF5xW+_TacKT8u26`lFzH%i+`J7a>aC&;1Tp+Y)ByEnVBBn zdUglbF*XZUoIS4;b#Xn*l>~{Bgnllizy$!v!F9N4q4C8_Z}g~9l!XY1?uxczK0{cf zg`g}#Lxy5^jd~cveTG>JRd+HeL#L*9-Ey>%xqz$cO`A5a)aCL@!LJNaG}C>KWu8B# z%MAQsc?(pm6M%|Um54lt*SYTA zjiwNLvw&D9#RXIOkNP|-Wr7=?TENGqq-221;|~mFCH-vaPZ{4>lp#*xiKgE(mqX>B zwNBrw$3=BmqjmYaosnA3ck9~K0%yqc`K*dG7Jy}Q=^mzhPLk$s?oYl_Ke>j+ zgkDIP8sWhQ)wusF$}D@$fI3DLsFBqRl{`5mLsglSGbX3gfDtj`V2NhX#|!I{7fimO zEoWr8cCC@+PkEvnOEHjWBV&yIOWu#LE(YG=@t`KeT$MSVR|`;murmWo4wKM^kbqcD z9aOx@<&2jpj?8{Hq~T-K zp)Ect@QRN=kQ?I91|l)MjC}>b)Sx?K`Ney6&L*ME+}(h56ffz{U0l7w0#gn)A8|~s zm_8EiIaFNV9Hzb*S>P-V=k@pZC!z?QJp02`nLy(>d#`V(QW%T6XGujL`lJu8r&xgK z73WXx@6OW|Tge>3ne~abnWj$q*(4Ulq~U1AMKUI#iHXdqtG|CrlgVRVPAcQGzWc?a zoWvp0K64jbS5GQ!1o~z~rCs=>cx!{iIJ#1LDsvI6Vbb0MacR&Tkb1KJ+@QwJZO6wo;}9vILBwL ztDgOvfvi-iyu!>uCkR1#P)>TDoh303OB2c**?WvaNElq|^Syg~^knI~U^tN4*YEp+ zdx>ITe!l~EIBDq~2-w94M+`6yD{xkupQ0Y9FjDch_ASW*=MjN+m|F;^R#F-6cTh0w z?MZtK5W?Nz)nar`^x;}EZD z3RElQqi1DyIDFH&o5Pk>R}P*LneN_|81KL|>)p;QHhHwbvI4_sZI=6PT9Xx%bl+77 z*U#y*4$p7iRra*BG$0ef9JP@+=J_e14THN~w0{qHllKL9TxpK$RvZ|Th0K=iTeqGY zLbKE?oZ)ffKxw1ur>wjEqO_E(MV@}Vj54f87?spIdG=OT>!n91EV@q+iDPPPctC-H z>2Q$@B4NBY4=laaumY9Mu_6nL8PP@6>CS)RIjcBi&+gqZ+tYQ5B9=@*i(TulEWF|B z5L*&_a9W7GSnLaCKKW!oyaNWD$T!f^@P;M!&xAy64SGw!>ItP>A^a;uG7LwBIneZ1 zbPr^FVJpz$3PEGqnXNM`N#*$ARPX;}^8n2+J?W!;{PFk|k zT8cxWK(t0rQfI)AZ_P5ddBU|*oL02b zm7S%Wk!EdB(@|wTii5&phI3AbLh5Z#dc$VitJ$;aFT$#){z`nu)b{@`Z75W23{0cACGbGUH8l z+S%D{*zl%ilg$`b9NYe;l$a@xlh;G8|QY^j`Eu= zf2u_CZx!hlSLXVh0Brs@n)|-Y9{hV-6D-jZd(IBaFmx+()_;hzTSwc+4qV6>GDK{amKDd!bUOR2E!N6s>rGrxSdG-n1*5bKg> zMLgEPq+yxx^NXbO`xn=u5w&n^qV$*-gZw=~XRPpW4jJ`ksPf0(5)zwOj?hdjELhks z+NPz9aq;#YUp}2-cKiVIFB22{!V9fy7uG6)0Yj%u>-E)M-wAzUf=Y=4VjUsh?@5T+ zKGqLt3dY5ZUQAY^unqtI@2s3rf%!wCZi=JkCI|S;913Bi;f72ecUz!|=D0L2SCeyp z`xl!(X-A}$2p!ZD2Mb|{dwJ#H0=$7Z&1NkMvtvjKqJC^)!j1CguJ4blzdnZ1HCs#T zaFaG%up*?=J)I;36x#wluU4NF;FCDfOqWTUHmtM3Q*MhiHIb%@WadXI7gv^KbV^9z zaSo!*$gdQa&)tMebyoE0`h6dTJ`@ENX_OF+2|F;P(Za5tB+;*U@XXou?~-~wPC)}% z$fh~dmwErsLa&qIuC};(czf$3T%190wO~)={0CnGg~rb_9s*T+_ojJI4j5--ej#Aw zj9=$A=SA%c2WIEgifO#hN)T=oRK0iy$R`#3Fm@xN4q5lqskJD>ltQpf`>-bm1ARi) zp+HRc@l7D$Altpmkud5ra z3ZB2~!i9?{KA7e5gLnvN8_vs(3>xF99*iBow*;&5)4kNR+CP zd9mv-^+HM!&EZwl`dznFO{kK%KFzJI%s;e(OdPrs^*GGswM{Ba)?Zn@K3`Zz+jdyHI7FTm(y-883Y+ zmb)v=ABrC3WssppumN?3myOLGIHuaW2OTnc_9f;&o3JYqbG0L*!ozt;XS)7WHeB#li8HqhD_N8brwm>Z!q)9+u*lFn`Pm+rYqNFh zC1$~m%fjrS?NZ`E%U~uKi^~vPA)Xj)Mx-UJ1c>WGF@B|_R;^01&>Za1b|tp#Pb@{a zFF*an&JBrY>r7|GSqec+f2K^SMmB`;w)u&I84EmCbIs43f&WUM#hHEnl%K=VXaB6 z`pfpk<|lA0#$nhb6gG)zsieU_*O91ls8_cpI9SO_y8Lbu3Zm@4$ply_@gnjXWRdk9 z5Fp&3vqFAb|D&fo)oL3_hxKj&vBkdj!>feTJY8cD|GIBq_PXTtw_F~;2KjW2i96%R zov%Sc;x?^;jv!!A@Dt{=$6kI?JeZOZvJN7qd2XFp$x$?towE#5PHW6qFw^L2E;s+x871?^Y2rzIY@@4%#i_8c@W=5g(yY4+ zg;RQJh$u+5&X_iRy6@{lPbvcz9UPk7W*m16xFEw1JxnH> zBqXg@Ssr_!%FJ{Z_Ver~ZB(yMj#2~QC}@5{zkQGk-}}l{cG1T}Hsq||bI4T%Ah*S} zx3-|aLBpkUWF^|~v4n8&@z0#B-GH{u=n`x*m{Z518Vd2b^QE`j%crrnXvNYqC>H2H(vj!t!Kz9#uBcu zi?K`5V~53qRAJI+vK){geR#j{zhRS!sa}v0H9FsBJ{!ee?KL( zy?utPx#Km#Om{oueL5m~qIW;Z=Kg==pSmtd!Et1&V-hN*Y;Ivus%cuP)GgL+;(gXt0I7Q4pU) zRhf<($m$gh);e0e-XuJpl+4oVWx{2h-OfBQAeJ$}XVN?DXGVndo`O;2XN#T|SgpR} z^2+9A)0VU?FeoK{rImye2%La(Mc&_cnaf3gZ^RG`=P$|{tzeSey0J8ZPqOTbFr$o9 z{-!1`0}j_*o?Oi~yDQ6Q{^3|Pn&l&g(0=hr^z?)I*;e+E&p=XMo$?R)0_V;JrloR5 zz`fWmohk&}2ke0l;#+_`MRATRN3HN$aUpu{GwZ%5<_d6{Odg2Nl9a!_S1|pO89-l{ zY6U4^yr;YU0H)I@9j>j`D=lTdik`{p^1Gc>SNm={>nC@2q|8g6*Q?uE!w(~>SKGT` zz@rbRzl&~BVU6q;#2?m*1)L|e2UBH^t6r?yE+Fx8enJUrt?{488-6dFZKqs#JcDrf*h}Mf!1FeTs(0&B|@2rBEM{KNQtdQ_{3shV)Ekiut|G zU;l5lCUI5&u&iUEs=gZE_h`#VT0Qof5`Pb>lUshHg7bd~L7+rNC z`rj9M&-l6ohV_rvDN2Q_Ue~7khjvKYCNEF|wo931{% zoXj2`y*2}<_}R+gQ1)}@yGs>YM@ZW8)=Z1HY-v6)wp~YAW&jPj07_G3@sC?}?z~>L zJ0T`94Lj&teC6>@uBZZhqnqnUPwE%3L(oiz0B zSvba8UB(ob!NDLOCnpE2X!x|O-8l2W{8$Go2w5M~e|CD0Z6txO{IfNo7pRc4X_R)z&2+~7m=8p@gY7$%pK;YqYY zQrSzi4LdV94XK@pN-`}a^d^9!4i=rIRSVM$2p93?gAf#Xe{hXN9 z4>=7MFMJ3w&=BVx!v0s?${f{j>WO+cp<2g_5ltAhUg3DY&NGPr3~&i%M#7-7KJ+sU3pi=SC_MGap)e&>Rz1+(49n=(O*0ixRn7cg6&3O#y= zzpl^@KhwNyb=h4fifw>+K~KGqb;s~(&EJ5+2-OQfB_Z3_QQ+_#iJge`cfS1x1JR1D z#-+Ov&TX9uwHh}JxT*`{v&7Q_KJB-CIr$r69IV3Lxo3UYnE~y0^QKJ{!0`_BgrKlW zAxzU7-b%1(Ydseht$m!O$Eyvc7|GAt9rq%khr$D<)mq zyI~%D`L-{r-dwi3_gh2xcCoku>0=lW%Jo#__c@0S3EzA#emtH<5^6M5s{aC?&C-@bb~4Iehas zID2`KcQ0gh?AcS^yqgDVh(C~Uu8oQ`$-tmwTu&Cq7td_abcU(UijKGV6~fvh9&>SF z9c|orufF4ABE)9oNv5UTE8Gt7u290Mr&hJBPEB06Lol{L3WaAM`_=JXwTOx0n-hJL2Lc>Hq&${s{6 zl_^j}>M)=63x?SPj~hqK`6Y=~m!E2GOR2!d^DN|^P_-H>^)MRp9bV0AfWAxYnHRx# zo*7Z1e{Elm{_zLr4mUb!E#+J} z-+&0WTEV@-w?b%;gY%)USmjm*i}6UlFszu)L`)9QC~IedP>@f3ofi7;Qru?RuP zmrc<{dvH+|%O~-?YwjX^Z+qGth(0%h`oX{#_Ar!yloC*rQ>{RcU{x@=`u1JSNo0kp zNM9I(4wc$sWs>fd{E5MJaz>Dxqvz}ya``R)MyFDz#a-z5!skJVyLK(Bl6}-uk9Fdt zwb7DTzLSekP6MwVF@#G315yl0jx~8&-x0~U9lT%82%AAldPOX(Yn|B1^k^MD=5(QNbMFLvsK*tIV?x!0Vjrv{4fJrTOb4L8o<_~ft%BXZes79C z*~PM?sKuyZ;cNdSH1@)l@+&cZnTiFYWCjR>Ay?Y<<2_ozbSl6if=*CY8Zc=2MMLuw zCHBG|%xcVe3VCYidW9J$&^$C$LA3T!fl+J{$~aX;mVYBV=W)n5nY=cgR@Y1>QtT8v zqCp$KcuE?;cHXh5by=@DGV+GuE%D{uxw6_Xui8hc+ZE=ZZH zQnNFoxN=bB9tuWJgCdVROhN#2n$c4@gflG;Sz{0E*_fMau)?VAQ{|3?M2ZMbF}*Z$h0)mGlMv!MiACJ!AYBXf2%=0;d63pwHp8gr_ZgN&I>>VNBmImVRL zqy3GMZ|>EUl<8Ba3V%K-=sq!)n8Urc*TkVG25&kHK_m^#pP^1-)I`2CnN%&Q#jvkm2)u7(g*R_pna8%f|3rKuV_Yb?)C!{J|@jcKzt*TU4w|9_@=!HQa z)xJJ86RxNj&Axu(#Hr;&;V^2LN2(m#^Zhtv#UmRjTSM? z4iOHVcxr{!VY+;GbYo-d)6yXKDC>E`U;>j(&!>iFq`+0YTI{=SDTQ!vR;yQ0eFttB zVy{{D0+t$81PBmiw-7mZsqqM=aPC~F{sw=kIxzCVHfpuHkcz;t0*;srw>X``n*mhj$*B)6nrXHJS0styGk_j3uaK3D!6}kq3C#4U=xBZS&IZJl zeBxahub7y7qxqo~dhyvZ)50Q^5N6tdBgSz+wAgw&?%3g9bcQmT@uQ<78cxn1tc$^1 zsAGh5M#0tS@AFfv(>dDAlFKPnKUC-Z0E`r_BF$e$h8sE+Xee~1y{#LC0Kt}~-9w!Y z#5*J=KGhr}E!pY*e4AFVZ+S6V2dqvZt&_i>mK$fV=hT)vM2`LGJkj8cAs*IThA9`*fp=n4{6HP+ALLlC2-a#DB;b z?A)CC;a61J2=&Xx&aS_sqnBPaA?PTEDx#Uw+M)c#QNaD%PDSdz0S`pXmjgZ6u>sM-{6%nDd5fL&!7W6bKLjp zUaqfKHSy8xxN>n?KIl*3&cC^%mP=TuLA^96HBH)L-l_vX9exno=9G5Kla67~7*Spu>bLj zTeD`gX(=Ze0P-RwJ4-53aqkAg zMyH8wvybCXmSG)VZYL(XYb7}-T-tfNd7-0_Cv&d*sxW*)$2+%tW^mm(+u!e93WCMV zwn^O36^B4irB7)1gQn=*OzmR}iF%U7)=*2)@A}Klx_#Q5=gCT@$BOlU z1DxeAeW=7+Y=_L0=Zp~LhNAI!`m|rsP?Jyr%%I>wr>oR15vp2yTteM^O;VYp0mgUu zgD;nElkVT|03$H9M8Drs)+IOlUR-5;r!2Q)z^OhKg>6>VTW7XLCRnds8yOK%>Y?u) zoO%z*6kI$$?TXuGay-d3O*R)SEXr2h@=Nk~KE8yQoJWrz=YG!XL#iw+(xQ5%0g%r; zr_$2q*=i~NUC+rL@%XQ1HX=!bxj#v~UnyrK&;a!1tOTQ{_4mrh!)O^a?kE8x3ydo1 z?&u(yq9|#*NdNnixLOLoy*hfcW=Z9k+4c-JoJ}2hE%h#AUa|6zTKnW8mX^U>Jjn6- z$R2fK5~=KzCNQeDr+xq3GHh%^N!pV0%gr|3ny%)bX!b2SR*f@+8Qe9Mq7G57Kb2(F zvouY4nKo^5gVHKZOEn9ebvSFjAo;x?cAl+KEWVr+@bS`(?kl)~CzOH&>O@h!Q%bPS z9llLQcHdVO_A)D1J^w-6@bwDy_KI;s#5$hUOZxTN?zBHc7y$zR= zIaEjzlC+X)B~xZfMF?4@%o|Y%nWrQ}6qTYVgd${!keTdKTS79=Nk~Fw!nq#%yzi%T zKAzw2oa4)W-_)|K=Re%{bzcKU2Pki6W~%(Tuf`vYeYY{XgjO)Ua!O;Ng?7f-enXoW|<}q8ME~lD)cJZ^) z^$jZ8o(X+>Mi*3^c#s+bFB1zELP74~hyk{G1dtOj5WkMTW5Y59OLN z6~2qO=O(|KoK0>PD^EP87%^6X{N3pv{j#&=P(Ey2B<;6qOxNnOMioY-;RpaKn5R$RlHg6z5JqvyR^M) zjvl{!yY!!=(Ell~th_or7nh3Of4fQM4w#a4j^a|ai*sC$l_j*$g5h8DlMfiiOaO5# zEd1^yWnNMO&%e`C=lhoQ&YY>k#lZAjs{2A4h_afH@cp@{KC85o&enN0igV+2{a%Pl zVusLU0SUT-0F<=DMNDaovpdavY0)8}qW3YyN^RKva7J}omV5pFpkwFRvd?#~W z??WEwC9|{+GNU(dSCt5s^7=svG3__GVAR}pE1=`fq|P&K#3|b3M;um18Ic6nM^>jD zAzp(KlDKU5ss^r}lSvUMzqzl>)zrY?y-hyQNFug#6Am4?kRtzV+N<2A%2rH5Xpe)$ zdBje&jnxmifRmd0ji!c-b%T;IqMEIS+^8?kd9(Z}d6V9y>z23Fg{Dz!Omx$7yHs~s zP$;mA$89Ga-}=ZwiLVcQX2Te;Fy|`!g=@&*iU27T(?Dt&kS@>#P`VAyEn0f3M!~&` zU=e}F3_%GnnG2LhKLzV>cmz7aUzp+R!iBfB@KV0ItN>qO8JW^6j}W3zH^kr6gOe{y z9wy9d+=T{Jw7UmuHo|!D)!CfO0hh>JzyZKV1QEO25i}zdhJCZMVr5#qyu4UM*E@^` zoIKs!IN}`NeDr%bKvQXHk)Z|Qhs2|ntB>AX1AB_h1p|J-3a$GQ*yI4a>jEz10MUTd z3}H zQ~!&gSQAs{Bgl<)2}m84GtE1GkGiyn)#crJw2xl-@W+7t-eD({&!#AM z01U!e*mimsEt~KWDA>P$Pvs3m6KXq%W(sxMv)}H&0Lh(5(Jr?%FjHX-*Gf7k`qA7c zi8`n_1Q(h5i4{<1vtF|OQBm4*aLO=!X1KS(Uy<@Mnbhb7$`DLLJ@vd2JfwJlDa5K* zjJzkV8kzEhX1*8y-o0CCSZX0Y3Bn;-pvYZnC@K4wMH_Xu-1) z?$EP8C=Zoplc^8UfVP9n7Hl9G@GgA8(!kgT9!7JKpJ%8s3j_g|i6f^>pMl4L}PF9%a85Z|YarmH3b01S_pvE(IbHFAk>_^k#Cg>KEhQP5&7g zA-T2>Nr5W}U=))rpQ|fvgusD$FJ~hMb!sEV>9$9p0SDWTLj_hWjnx`g-@$<8>g094 zvbEJ^K#Yf_1J)BA00xX$3ej(P!Mfp6PkCcD_ozM$xCgc(!c6t7Qdr| z?ywJ824CcmS;Zejc+3NWYFDA&9n`pe-rL%#0er}j#@!SODkPfCu$>e9fZiRpI&fI| z)tPjz8TzUioG^=JH3JKCMU{RaOd0G1>mC^9LOmK*Yr}-tSQr+;NkJ0>4(OXDrKK5n zKWG@f{!~PS@d@c4w6G9KY&f{Ij>LziYmx_jP-wSt*B{;_#`!QIA>?r%>O@oz1Ha++ z7Btav5l%D6AqZvV0~`l@I$;>>unAxU)Or=racf08L-T*yLV4(V+t@B1o}duh%0EMg z!Lcq>f0|5!ItIH6lzi~Th=}oG<5Uvk=R`6SZ_It4BofT!V4^%N+lR8escGt7;-9Bl zB*g|3fZSxRxrcJJMV?@8>5aoG-2T+5Qw2t-s3ExR`StGTUBUf4EN3+|Ff2iW+P^Bp z?Tni+2yiJjzh;=(QZgoZRGS@L7V1t=pFvOq+9MndyS_VSLDay;P?Mp2(?L~%u^FCC z-H@8ly~?$f3E2OljyK;D1Tv(UxcJ*aa4a?{A-o$YXU2z6Z@)h@adie=1;&Z*-_zHS z(fo3AagjK0I;liduiV7TfYk)X(3X=70C71G+Qz6iC(!FYY5bkkMe3}*vm9#=-IPp&@^a~Xr^Nk00k8j z;ezjmPtMh~%DZy%`_j~4ZLTK9U=gF)4wR%Ng`I18^Y^EhE|@2YQz(SB03H#5bL@{; z2?_i zmEyfE6aVRgn-ndaJ2f?R*n#_wqAEq8DZD9w$g^>bvew{8z#2cHjTZDd&}=E5WRkMv zSA-}YIf6h`+=UJjU7u2*sIL(%gr^{GamON5$_c|jb4AW=N9+WpcYUVixiD#Auz=W9 zFH1{e#Czq+wuHyWLs<%5y+ZL!p%Ap2JS`HT;>W$}b2v43dJ3Zg__X7b#NS7T0_WS{ zXbspE7FGyEGj;=#mJ_32kv<_35fp*Wj0cB=s&#qR-h zfUO5P)56gxln>UyW+0~C1{(QXPVkoCS?_(h2R50E!B6f2rw{Q@a56zXhIR%9^=-&Q z|MpF=E)Pv+L{BX`M(7OI!$GDzc`E3S_z_`#qHJDlG}LcX@LN7kOl+$=j)MYA;zfoq z`$Of~G*UR)VSpl+E-{6?LO`sBhBo`AmA!(h@c1ruzXVqV6fk)r$1NWUsy_S%IyFta z<$&8Uf*ZF4c_S{LtzlxCfERlP#5xg z)nhLtT{v|r^n+EPJOjo6TF9rN%KWYT5h@+HBUl@PKBdKiG*(i0%SG|Kr@JfsWa2f? zMCl%S01*{I!eV zYfCsCZ1Jx5mgO~!N%|0=^PX<1&ZxhYR$g~uXwD6%oT;HkxI;I*kXd(FTk zsB@Vu!oF$)QYU}$LQr5dk_qIyMpcLQv2FBpnx07rp6+U=JWsr+wMqO6d0 zXw{Yds5(vuTnA{}+23ApI7tY(XO>qmmogCXpS@>p?jKcebNx-9+|v8?xeoxY7e}UE zm~RmsLZ%J((l2LY%#)@U!NZbTt}EsZd}7Dl`BnMje6j+|k&~HuiTW-E8G`CaY*K z8(u+AeMgX?$3kI#aiZjWvIXbbO-nA?g!W-kGMhhvu0CH=hS z_;|;M!{@wO{=N{geJ+29Y4K}m&sEl*mA)a9tL_TJ#kH$u@2LEL)pgXIDQWWIn~Gj5 zB4Qe}%1s^K5N_j$=MMZ&bUWRDquUk#YZ6WRPjvgn|2lI&`Hy`y3;Ta^(XahC2VU#{ zffoz67=8fx8&Yw9oZUox0@7cleo+!Pz^%G{ZZw<*M_~WS0k;JARAbBNx)|ktkbbsCn+RjTO0cC zO$)D}U$=Y4$>>#Kd`)4$83RN@d)>_GGNhD)Qz-b1vVe#|AKx2_{|Oe{ zMHDKP{YgaLkAh3inSe7rtwQK9G6F+Y5bD?d z_02XNk4vr`WG)={28lUy@VpCAninyU=11cWq*hY0FHu}ROws2%iY!0 z7(^#9sVjBZ0(%y&BCn`HPbSH84QJ5sWgS`ZxhfUHxWylj=dLUBWO?`k_9; z=rP0soJx%W?{xdPm@=l!*i5R6{K_s;@K*nL7mM8Zc! zo5CH!zK)g>!3QRr`&TWJ*R|U)4JUaE42F3;k{ZiRF=|8nU zkOBwik8oxMWv;YOG1J8 zqqP;Z!oui%Zzd=~NVh2>GC7JF9K0%VVy&BZ!v9Tw&6C*Xk!{1m!?E6LEnl z?I|B<`UrK&e~7RUMOB9G{m+#<{u(K~))DG7zEmD#?bYQAf%xXcriL_U=A!tm%$T*4 z)qEaT@fNu#n9uZQcqC9jN`q+b(WC2Qo$f6nb;?E_!&z`Lg8g`cKp;PkMRctW{VsDy z`6p!oalSs}`EiCCJOKZv&L#W)v;|xiUwuFvjHcUiAqQ&hc zwAoQ3!sVLG1>0*}&$4S5uFhQmM#9eMhvIwbHZ=PleSV;VV$6t{vwj}}JXSt=RDjSy z2sQWZOPuY3wGu{Hn4V&mo7ezB%fRb9f2G(hVYOrG$xr_kg%3(-grWGuHqyt_vwd(1 z`Y6zvkg2rt8XPb*z;m9mfF!quLFS6e+MX%z7-3%`I9c*xhA{44Qvlo!%74%|LC>lBb|x+r4n4@Elm!VMlsl0+l#Ql2(5s+Z#6$^0v33Z*!Q~>J{D2eG*PtPfC2|ml zf7{=gtWK8-xLvhiB8XaleC~t*bzhYAv1m28CjrF;+YD*9auI(|6Suly3{>YJuZtZF zPwpDWgD?a6GMB0s{8S&j)ku2LBB$(#(sS$Wp^_A9Mt;y}-q>65SvE z85(80bW@Wv%q+Cok40W`aKMiQ#U*tVX8W+mgvE@MlqW)?5{5x^vLe~V$@?vV+YF2H z(z>CqUZ%MH`uduhw$Rh?%KJXA^;(BJCyB>V+Jk8pcr$VM32>=L`GW5Jy1EeLVBidL zbHlS3gX4&_%}4m^^5A%C`4ShLtGhd6prSSyG1!3AXfK&2?D34y7#AC}&_2z{ zK{UGK&O1QD{1l}kp|-BZJXmg-hK$_!88QEh!kg(A8Gg2 z*Jmx(C!7TVAZKK-t6&ZouW(hKi}{@;di}!$1Nz*#b8MUlhLWqYHZ07^L0&NU2e1nQ zvdy5GFbpnB8D`t*x{&i@Xnn}EjEkOzUj^~yI3d@akl|#p-qAkCRU+Il?O$1$lXKbH z8k_-url>Q%eEECY7ke6K0<2e4?beN$VO0X6av;C^N(+kuZbE7Jov@vNwQhSmD1dnU z#DqT^8xU9&LZBIg@q_7Qz?22E0dUJme?PcDVD$q$?68j+Uc_Pp2(~^ceuux$5}H&i zd}181>b(!3d?6|NfUJ}hJ~aqR3CCo6w!+`Ca5&wU=!>{retONcbK3VE7|sw5$fX1K zP8Jx!FjR|us4?kyAiv@gd7V5BLZE0YpgoXE8k?I_GJaq_P+$ZyKHl!c{5?U#EWM+# z%&Jisk>Nm;M=GX@Hqek_<48f6BJQAviLagd%&&~BtZ@)R0NFusj5Djr77k%xR?^dS zJ4)8&;ZFcwz(s8HW|Iovj^Sze2%Hww5#b!@gqb8FgoWp(BVCReG*rgax2}A}g^UxX~5qg;*frE7e zRxyt5chuJ!B;*s%^0CFGLP^~`8r0yDOag;$-INh+^;wZoS8m2*>#@A+!sYEt^A?7( z#1YzvC5ksl=DJlo$P&7-S1Bhe3!oc3OYrZ3qJcOX{H9PV4;HICPVvFzf_AeBnY` z?clMitG;ko30CI2`sax2$uTf8@za0^5(;U&Jp{?Z6z0_{B&&k)0aFOfL5@YMDZe=SBj0XC~!mqQUaC@?i(;ZoZb34rXU%iECA04T$RLw8xmx_ zP-7y75>%5xOzq&+f9#kSz#Gg%Nouga0UA#?kp~p@a$iNfy3p@4%VUE*b$i$GDV!~xw z*=Gh{3S6Gd8v z2*-S`8Oh#%K;Xv8GPyj&OMU1@?I5@+FzuyKW^tk-8wh|oDji}PJ&vp(6h>&C*8NBk z@BquW#$)Y6E^hAb_!wN^4FK*tn^7pl@UApGm)$o%kVjB4(a5k5hmGwpI zgF_Uo-cW@<10X&5S6;>ZO*E272WnbF;3x)KPQ)8$3d}|6Cn4dEl7OIxJAuSAGLmZ` zvL1#b@~9q;A5i6i-%wC6o1eG>>jT@k7|k@)%dGyB*L!G=J3$mdMd+^Pv{vkTWL5U!rv6}yf`b&X>Ti)~f8I-8@S9ps1VnIPZ60I6qI&mYBoLrJs`duCNx zHbK6#v4PqQazX+mmU8>S{_$AXq8Hk*{0`>)K4f14tONt^i4B4T(M2A&6IWtG{{%VW zegMRHW8i-f&cnR)G}s;(u0MSE8krP-B%A2IK~uaGFj9*-AAHXVzqJ`J$RmhQrb#vy zR#uQ*UmIn)@C-Xik;08kO+guGA08yI8^35%Kk!{>zOep5kP7CdBje+87(HMn9ji{* z7ZsxFz~BX!kA%%5D?`)s=Rw3_i&rdgluyJcJ&?cZ(g=cA$y_*5??=4(h*BO&=Ac^Q z7R8~BVJJm>YZ%xF27XkndK@&>_7XLhks1O+Irvn|^0vt^C{li9KLGfQ**vri59+{7 z2Z3`Enm}_$wxGlKm34(NfyqJD1T9khr@MEzJgNtD0ucK1sMvi6a9sYUCnj^vh!4zy zv?{=o#-)qMdm49`_5~|z^(zrHJEXuu2x@F<3bG$t!msHHu{q9wT!zj)ik7UFd(h1C z-sY#*n&}3!YZxy9ejnsbUv4XJ;L@b}Bpj>$<8DWGf2T#ycgvRxBTaiw^KM(8H1w2k zu!MpeEW(Kv+Df$gTRxU25`6o{uurg&DU=-{*GjD7Qg>^hScHuM$oB(iFriFaD%|0N z5Jey3O*3Mr!D?5gW`V0ei-ojmROul-ot!om!t0Ylx%u|@Cz%NmsA&!##xBG0Wgzk{ zPo%85(V^5#r0z`s%EI`K8%9m9S5s|F=0j)mPT z21ML>(B+^14$|dO@3wm}3#W2?6&P*)u?Hr7#l|5Sv>^2K_Cf@VPmew@;Cs%be_Z&_S~){tl`{{`&^nPnr>)M}7>u-&DG?j@hJQsP zvpao%c=89UrNH?*1R)d@RPO~8gTra85Bw#as6c+$rjoe!(kH`{%kt$rEtH@Rmwp$( zrwl7cFDtXv>1p<)0xG>GLbr#dCaB6&+awlWZ-Y~A zP=0{JcktNHH9dJIE2o{8C_ixSrO?MoZCH=o)$)+qSC$2^l^WX-*!QUq;J|5fkDE7JYG(!@Gch zIt&*)!?2&r9kHL8;^mXF>;d}0M-kI9u8?awpTH_We#67x;4tN9vHs#U=jBjhlNqj7 ze4J*-O+z9tgP{7f8MYxl02v^DrKf=vE6i!x8PVwDyxy~Ck8DFtcbt`IdJ%WeY&gF_ zD3|N`S+cS6U$*iJ3h3L8NJ*V$qhYrhc@_avhO1$$ax{gzG6`xc2L?V5V84|oQseXE zT|F}16YQ72Ug_OL%qrMov85d@P?nj$IVU;)9F`){yv8NjCRC72Lc2ZVob-%9+=5H^LHcN&txLsUo13hBsVfP7P3I@r5tWmW#ELAFE zLO%|v&oU32TvbmxVkR0I0Dl5dz$hGIU)%)O0x~kNh#U8A1f1=?xqOMLZEy;NbQhlO z%(O{6sg&5@Vt()-?tHxR%;mMY%g7$=aYWygn2^vr>}Zj(S&1zoev(?`Jk?ZeGHEhO z8cj)!MP(l=gRt!$vbcD(W*VGuPA26hKoV_&!?*Mp`%f=_fBAbKe%iA?u`?Ph?LljY z?B(jdgPcr3x7ggDXE3IBua7F-vu5!WFcjsdhfTI1d|M8?ZTU;1nSYTnr;Ok`Ic~f8 z7L|s%D!pH9NPL~nPin?g3MJV6vX6aw=)zJsZN#-*FPapitgt4P20JYpZ_900ZWOH?ctI!Z3Qr!UIcVHs@;jn0f2i>^tftyCuVoc$^74zn&7f zt1Arts75nN3$!L^b)udEPpF||Cg51{nI9Xwac`iOepvq2A$?BNUdd3 zcphj8yjlMYL;rKfBCzPANZpU7Fa$h}kHf$NK$N#GxGwOv@bgfxucVCR7&BCxlF0NsCZTRYH zDE|$x%+#LI2lA`B;I*Nxx?sL}Km6_)g9AZ;;>W?@uw~%=`_tH#x>nF4gvy{PFelqT zs+vbXx`~nYNmC1MX%ELe;R41i%$wB9A<0_4n!b6KHkqWu{+7(#Fdf0!kr^>lnVE!+iacu-+#` zbWj{U4-BvlX9x^9W&!d-$flAXzaEP|JNrCTZEGMuXbi|v6d4ny4{Wjx`U934h*M= zhqBCYxr)GD^i>9tEWPF@1nz?7;lUtI^bw5ZOU#b4yk@4Rz)Rt}ihd6p3l-R;W#_0X znsA^LwjLbJwDV>YfUVHOK|tH%7&@z&v=<+TEO?1%4Gd#IkbwD2VJMX}exBGJWC7~y z?{-Cyrc+3xdF47lB@)o-n4v-+FVf2-*%YQo0C<|A!!j<$A?nB(yrbD=Q`Lp~2wqrAwL8*oBH$Q5LiMLW(;(80#W-Y}sGt;(|>RNte~ zydL}^+aM_}4)<-ixUO3ONIVK=A!73~GTz8ILkKzWCK$+EsmH{deL=X-K?V0q1#1cg zUv~#EFI-A;!kqj!Syw^;`x7md!IOvl^pS&~;$=Xc1`1>m5BF3!V#AFQUq60ae5P~F z(HNj8lDvYf$RhPnYXDX@e@Salvtx-C7YZSYV*8~qrU#V*+SDt{{S}yr^?vyxhVa18 zpYtHJQ5Ui{3}J}GoW-^#22`Y;O>oLOb|UJjx=^zeKayL(1Y)6|iMmeh_3MLTz%pEf zl`W8fLn!!qcae!{to~RF-RmT+8|$Hxu#&~i4`&MXm?3+Kx;&BlaDRSb!m@g^9)Ryp-Tb$ zf<&!)8vu}eyLS2g{-z4snzxpO0&fMp(A;OwtO1k)JG^^ZQGpH(w4@J+&#=p94y|eQ zn7h3C{W30DBFR^A+e_5itmfj~)^nQj}JiOL?;3aH)2qC?QF{{f1 z&F6x`a>vq=boUV9y9xjXhQI*t64w?JlRk1wH#?lEmFWHFvhUzdoEgi$Li4IoPa)sO z%mkj{?xloLsiH#(j5ip-l%sNRxw=Jz4#ocya6zEEX75!S25TI!;7i7CS3~ zqD;7=aNQXWJbu5<+F%>FK(}w-{>yIz0hjIQX>Y}R0!o{arN!uY>N0H;V`GRj8YdXS zX`VR6@ZUHsh;}v+o&i;5nK%s5UaM#)$ommvGIbqj7hD&Rca$p-fT?1(NW-*a(W7B6 z?2YDei-75h6%t3c=g)5k&WOffLw)ZBmX13H9tqwo{j7>43^9F)TExErI@Wk@n*jvZq-a7j^L zgv5SL$H=y<;N?r|{CHeE?d|RPYVuD$LH4?719eIh{Yn^`rDr;e(9Er25miYy5V6%N zBiDD+ZxFGgJ~9Fi8(@CGqI86#QSEw3-YNCqK-bq&Zy;4;c;co>#&cmGbL7l^S(;ACM?<^S;W>NGs`kSgy~pPQM8^eT@N=B zOufil{9PHWH+tX44tsIkLAk(8%M`qUCLKLK6sFS0F>aK-^p*RTsdRkEcP^i`?TW>Ly&a^Mu_J#`xT6*-*ZuKEV<92?t*G!Zm;c95GC1vWom_|+ekGc=? zW8WYy)M$o(Vv~;>>w2itd&@UpicBQ0M-RinU6!|RIoUjlM_sZzrmincP}QJs;7oD! zGTq)fw&0pmc_t0_l1ZU8@T7R^zjGo`o_|t#P{n8KQRbsiWy36{YIs`CcVQB?LA_Ut z&h48 zT`_x`?{)Zeon00-*k`HWvrT2nkV2D@p{(OOA!i$xDw(3qeM z5I=d2Mzguh9J%0s?hQNAZad%=LYs zS#Kr?X56?kkyDMj>0@0tkO|a!{^5Jq3lWx$#zBpTaML4_7LX<5!9Av~aud^3 z_T8_TY%>whjAxTwgoRLHy8ddQ4GrEG#a}Wz^<`R;pk)xevyFHuyykA#|L2){|2u{m zw19%D+{_Kfwh^ghD;GZvAHaq}7Y>Vi3Q%ASFOl*f$^r*k7~0o5gUJ!C27LH|^L+&y zF^%&`2t+Ut4mr(dItV=_cms&-de|@w-zHLc&XU+VuZ%HAB>}RK#wP8ZwT?Ll3~*1rgx)p+*E;ulU29gqRI~J+s-V$YhX@Vd2Hs`DME} zO-!p@OHeh~7vkPjVgoarL*-wV<#=p=kwz)!? zAfpk4F9$N0si`fJzX@Np>H*VtyV0I2oLl>NM!Inf<>g4x)(Yeut-gIA`IA@+o!?R2 zMi3M@>)|$Q!&RdTm89^mUKLqcAi^V3sH4NsFve_F)sFxQ`X&%oj#7PVm~zr&M#K7A zx_S%R1jv<;DkoQ+evSgq$g-$&OL@3QFvBhCGP}q!+X9dU<>=sRgfw?2UnEvu(t~fa z;mYW61-EvHhjqb9BE-TV22R`rI-6AG+-7JiQYXq>LuMvYuGanE1yz@_5tmm(4o^k} zxleCkNVH)ZXtrMFs7fZ852+DSVpE2GxgvEkgDxu|VZda}f~!3rUfQ4!>Xb$#r=sSQ!QB1GNaRFr>j|RaKdj@ltsrCx+uvs~;W+ zSJq~~wYy<@(q!E^&zcB+%Ef)JPbSHaMT!ZiZq?+y$e>G!7R0P}+&8BG@BVZoDv-7EMAlwX zED6C^PG#GHsvg!TK@?-izj#n*FVPHY3c>O}nSv3@k<@F?-(y4^uUXpjX}U~M71I=A zUbeqIYIyPccbj!RGv+?%*AjzOr=S_n$bgI-?18WOBds?1xFlfO4nHQlI8r!Xhw$xu zeY6z@GB){c9v(-T5o8#24hJ9jl)SgcC~t{)zOm^=>?D4A$cUp}y?hB{qY#S>B%GV{H9o8w7rN^GEP5LUF?7=t|%P z(>@h#kvgak7YsuzP)P|u>9w^hVegIiCiu*U;F=I)?P-ogjbswnAR+tDiOw+#K&ou_ zmoH%R^AFMP1})BA8lrRJ%UfzMy)yF9fL6TNss1vss&3pJGVt)^_v>a`28gvl*C1Dg z1hYVXgDkx$!86ac`raL0#C(%1}kh@J)yDB+5=AwV;jmC=tTJo?iLKB%)X z{JVgF<8Bf{Tk2SZaJ7{ga9;2QyZjpRl8saMX+dx;@$b>4`3ftsdAoO5i6ULd0dl-n zKwv~C@^yOATow+MQlB783MniqhA@mOT@JXjd!fAdo_7x0mVoF1TUpS$6_ zJ6fJhrn=6s=Skvm*!>G`y`!9Jhs5rl!6_&$&hE0efvJ)EGS)0q@58yowK;NN(v8t+ z_bw(o&A3#eFXIH=VB8&tvx>GAVUW-m_$od9!?-(4hEQZR3|}fh&g1pSTu=rHl8oIA z?W87r?=I~B)Bsyh+&3w>0r13&1<|FSh)DgpYC3dX@vn9KmEE)4~z;# zQFk0{$Em;KFSUpPz|FT-hG1RcGq)1eW``Sc_}k=EJEsV|UaH-Kz~3WUXEFDH>qG6} zqHiyfFYckKAnw%sFU*G_{GiDMdq+;rU; zDac@8W~BBkL0qoMv50O1H9~;wjN}tT_d7Snv^y4^W2cC4a;R8v(dxWKQfMKjz=8ZI zqLtXz6;ozoJMQ36Up4cYZ+%?2Sm116AKcE(vy<8G=sitQhg7*=55_!xs-~%qiAC(> zxz%B4FL;N>oAuY#8J*l+KU2k9|72gkHCV0uW^Td84nn3sKlQ6%_X7?B!gF^tx%I4D zowI+(`LnoK%$Oa%2dpHW^o0sAUIO_cRff>)<2B zBTe|M{W@Z-uJIXAj*bps zo_r!;j(hjEJv;CAro9mFg_u^Gwn@Ub^{z8DRXg7@hZ7j#2tQf)2J1_Z0!n-ItUe66 z{8+_G3V%~Ev0Sl+m^GYcBAf*$RKuFdx6po?8XG^af6_dP=~RIjNM+4p6R%|AmUw*o z^j|i*nY6j>yQhkaCR(jx0-goqH&QpLV%Xc^;lthOQvAwcd1glUwxI|o+NCC2Td&8w z=R`OUM5}RyAXDM%SLK0<9>>hMhicCNcezUU9O=&{a0g5UOGU~$YU&^1VQIPe(<{MoXbjUP3j8`Bx8zltEbYZI868bmdmCZ_lN=LGe-kdm zfldDTV>~oH>YB0ZEh33hEm(}6fZ|X?mPHh~jYrwlHTh^Y*wsZ`06|u`AyWqn32vy& z_+SxMOd0|`3HHs^Pa?PmQ$hXEh#Vor)!v&h`G{(4+e?AaYg1_Aws$@Bsnj`)D@-}!P5~8 zT-wuN{?EVUCME8z19YtCN>eUlG!X|M)n)U_Rr1)U<&)h_Sp(`qDKV#_x3+E6&=tUP z`d|gUBADlKsdLQ8VH>0SLPMI(SgN?17SIJ7;SWzaRniToG4V_Q5R|go=3}<7G-tE* z*~#(EF%83%e{h2TxPOB;{jNn*s;!6jJH5p@U^TbwvAN|Hdx=aTwq`M` z@$b{c6U*3$vb;H05C3u=@DNTmJ2N{tzV_o;>D(D)b@(gGHgrkBzudR4tz9%Cs97ml zt&lgh*zR50_x`J%NDk^}>dXo7spIJe%{2OG#MzEZE7aPlP4r!)EB3PXe^vd6hv`fD z^Vqr}SGr2$LpwtTq$2I&FfZkPGx;{_0Yh6``Hqm6#)rh!K2DeQq>&z`qK*Co3JcPZ zXs$^gdiB(Pu!Sj5aj?X=tw7|=rBAp7*^rQpZB=cad!*T8ZIrvHPamPZ|XNnJKPqJYUzoy`ZM5(wsZc->d9#uZFB!*36zUtCUr8`$hGuaPP3Qt+!K7ot$4v ziRn0>;WuVx+g0e$g|`;>C&Q1YVs+_K&zIE)CzDUQjgXmr%6r|P=sNZJbr=JP#vDfI z-jwWRF>bq79A}=GK_k{VH{qL@-2#%EkpGTj?xOg?hw6Qc)>KpsVX7yTbP?7AIc}`_ z$)kID+e8dnmL-K-lY5z0M6$L&6yI*Std%hQa{m!3DxaYy8sonCUA5Zpo8KSQ_#!G# z5liXFjFSNI{?^%@QKQq+k?$T{w9u;TlVm==L$Xu(r_m3JoV>h*Mzb5nl+A?rI{Q}} z;?%nHUVz!Qg>LO>%8Zc(mvm%`Xi}UdYM#QI1*Ltz^bPb?%9=+X$H$MpDmsqvIS|=C zyB_T-gmtGmxzQ>c3_8P@X9Lz?%$?a}9Dh=UYF=+&nm`Pqq>C=t!?ye$AVMi#oA}GZJ1u1eJ%%AtOBD> z-PS?}4*44LtvNr6*14-$M*g$G*qugM7lrPRl?+Z9jF%aiau%HIb}cfR{eJ1@Kc$Vv z8{|;#Ps+kx=0~?k-DmYo=nNj!yZ_X=p?Ul|x9(Dl>v@eej(RhwD#K1B>k6-|E>CoB zlZsXMPDp!qi`1nuG&w62Dx;ke{M+)&pNm{4;epXux-HbtO)TE!m5SCC6s|V;=|b>&K195Tg zD1qZD958K;Zff0i7@+cxE{iyuBrWh~Dkd&jLVt1IY;cwp|Hj(5-0-ukc@0^N{e^*u zcf#-sHA>BNf|UxA*tcDnVOC$1x3T1O&&- zfXr)JMocJfqp#U_tu9QZxF5hzLKi>p$Ki;&a8L?fFwT!p2sBO1!*WDI9;ykXMcv+` z<=}#_St6BZ66XT+RT5_Cu4}Z`#x#!FMyFAx>S2{)R-t=v64p)gb!~Ty#48~-11i3` zMu*xB9kM9LrDY@EzC?AQIJy^M`c20g#}|lc5x_!)(uH}HfC3ZK~}4+=0E!*`t#FX<)5oYDy$WwnfIR}uQ#PhOaEmr z*Yoh4LQEnv-aSO+5kUcq5Y%au4!hXnUOkEV)zh$p=*3@eE4!#q1sa%0tO-{yB?bOJnP!^61@Q&U+DwkaB=ed z&GZl?zP_4%d*Q+KsanK^AbZn)@%OWAmCX#DqQ^Iyi5DqR+QHlPNtkcZH&e;E&4QK) zh7O*xtrls`>AeS^k0XEZvd8OQ&9fd}Ua=pnUNyyaZpFk8Dv4;dHs0~nl?KO#zY85Q zSNb70hp@L#qs8K{2%(zCbCRO?}qmcxmU|S>Ep2f?gJ?fCaCG zfJ~(og|$EV7FplC(5oQFg!0VhC}Z2bXm3e6FqIYb~h7;bVSZ4 zCjnisU85B6jWD%nKk@n66mWig<4>u`uXzG@-5It1`ZcZ(X^lx+gv8GR2cMS)Wj^zo z4SD{*GY(gHO(U%O)O&v{ErBl~J70p?h}niE2(cj zMV`^Y&Bm889s+Zu?{%q?7%z(%k(gG0*BwR?9IXb2m~l&~a@$HgVd3ab|0XdZONmEz zJqh9U7rAc43Cz6LOau8q*U;?j-1X`zw;ogh!m}EEEGAD16`Kz{@!FCwoK?j0{J~$c zya>cS)X8wRhbd#?)9@8Hg!;xe**=w&T{=0l>_=%$;WdMcFW>1u$V6NE>qW)@_zNhj zA*34Fwk#Vx&%W%K6FVy8BJqgASv(5nyOH6!V-s^1^m#>@k2h5KG-)1~mz0<&KhRUq z7FT#a+qQPPDfqyFK7Uey)pNG&7Mr%uH^H#gWv!SGOJ6x|_uHSYL;J_Qu2Az6&HHA? zQ*@XVLB7QL0-1YDR-eh*_Nwhq-zguwP&umq6PT#?0^_f8yKm!>w{_b~RxH#G%E*_h zJ(Q969JE^XI&yaZAOC&1e)6RgzFWo}ok|b8I3ja5UuEjBwcnM{^Uvd!6XQ9ks9q<^ z%$y=pSC?-G;C2ZxSWiCV*Q0-pKfA%P9&T?L{D0%OK98-b3T_Rtct}i7_(ptjZCHua zd5@fYdh1Y8{d`HP z1IjltNM`=0RJk`JyboTLzP57Ll(nFX z(lz_fAu6gnmof{3Ow7;z$)cO@tT^v~hMMa2E)&Ko)5Shbn?E)isGJo3Y~D0i-}f%a z;l)unN|z4i?91Jk-I1YhTaNdYPhU{Q+u8j=k&G$z-x{A76%%Ap&Rr?8D^EIh2y+Bi% zArm5hB6zFkzRw+<+JO@rChAR#doo@y|Gn?t;7&wzhUkvS4MqFNsU8O*aHL9TCZO0yg;Htw8wdEb0e?w$4vc?2d}gZ3=P8GzCH6qNC=}s zM6}gQMNQ4b$Vg1b_qF<;e{m;eVH>fr5cuaWj#Tb7|NMc9lztcC4@N>>gwZ8>yMyj3 z;s2ks?}nTw4%|Wbb88l}59vw5pM-V^-zWT^=+J+D#BoMsFX8WvkNQA2@lgevr$3#YUFP*e!mXwrakbEN)I@?q7oV19k^VhEc zYHn%zz*7Rw*Dha9FUeW3X5-+9WYc*X6x6c1Fd9HB*xS>IuN`YDrW1+2Zj%xfb($_d zKK}mJ#;Rt_4*RasjX$-nuCDH+yu2m&g;d3qnffM$dw=H!0s{lxNx8WTa&vP_>}GB4 z{Bd25EY`V?qmAbJKgY(#TJtb4BxYsFZE6a{$WZLwxub$jGpcB2Xl(47ar3$2k4`Hy zk&&mC>{gw(cxOM}3}w~&Xg~7fWV&6E{u3fvXCk7) z{GA^74E&UC5yS1?S}%-cm(e0F7ty924{i3 zuU|!0BK{n#c=P7XoW0BkR}p&rFmPskX=y1z-gw{P_@tzcj9qzly1H?$wB=IM zmz}B`gDxK1&-)==ka=UvoSf6cH2@FVY;_74EwuaKEg$6R|9mzkL< z$)HDhD^|7}^Yr((pbOHGvKNolJs2DrS?;!KSY2CtAwnObk2PbrzC0b{x;EZh>ZB2V zL~uivDKV1Uf~J%qlvP{cL3d$)|Gjie4|jJ`;j+!;U$0*u6Rv5Dy;EaU;4KQ>Q<0m~&au~_pzaO8T zo*o@dLl++xN6*e~h3EE$95u`Fb<$apk%iTkL@B4`PQx~p@Cd@OZ|~l%SxKke_O78bXbUX7zvXR+{`U5R^u;?iRqinEYF~coL_ZjVoJRe z#Nc;|=TGTMFVm$<Bd_Yea4bp0)GjC}2o3GP@D!MJ#vLy_a{9XU@>J)< zvs66EkN+HNE3@UsEs5FLSC}Wj+!r zo?ubFnO{=ETj1F+`m2j@c)GjeJBC`56(+|0e0|3j;Q7v_uW}XX&Q-9z6AK`e-*yHpgxU(iLkK!Q|}}q@)ZMR>|2v< zcH_nkW#tooQ&95(Co^qYEH6ZuY31L?KQd!N($;tho`yvZVA}h-F=1j?~NzcuWtf~DT z^tcWtrpeFG_I2lXc6O?&suugQkrVIEJ^i{&@#3+92R(yZKfNjki>yWuQ?kA7G)DnV zl8LDwU6`Dlw6e0AoR}yqh71;);q^U!9l8 zb#(Ei6^!-?Fz#IV^Jk1mBt);5fe+Tu2xb~!XkAi^%iL^ciT2tDK~1OXAxkzbZK=cK=48D$8Rn# zE=ay1J#ivYPZ>KybZ*kyQQ21qC_Z5aP6@B%cEy+Me^E7 zmu7w;3_Eut<}0RC6K|!Xqa(Y)4@xdGO1g76E?yjKPLSZ(MENmx^bnqJSGQ>DbIQ8Q zhncz-G&?vrI5E*Bp5)(tl8Q(DZ#fcNHSon)v#bIHczbwk7Za!Dhav1*NaUeK2 z*q!w3*>_>=hRut3E?2Hx>B!JL>rU+Pr8CFy$jPhjq`DRs+ZziFv%d;5JI^vOFg*B` zpZ@VC&~Z=N-K@9-mi=mRadA`A)8m%T&Kp&J)Z;Oy?Ce&~ zON2CT3$`T5$Vf>IZ!ssE8X5f^^kID!6cj9;aQmDiPK>R=XlrwGvmr*XZ>{P1A?jP( z4HtvIbN4w1g#*XK1_+(`JDwCX7G79bs99n=E^mDL z{SH07A-Q+Anufc&)QVA~7=xuFxY9dscVN}eydbq4`4N^`w)TPr#i#1)*T(JRdR4yO z0Ra{k(S!-%jtLPaF#DJ^PWGa5rdC}~3JWXi_Sm}sqXHFA9WB>Y{RiZNg5_l{>&?U1 zNxUcTbT&4Aa9A*)+9DG!o}SjLKCET^ z#bF`A!RhJ+{`mNq96veKMZ*ORX2KZ5PH~)R3{6$6ErBZYPh0MbCw=VX_Vw*{$ z4-&+8WMmrZ>l157Mn>um-)wau&yeHc<#n2DPuHOELSeRSTf)4sI@{KzzU)7+u&Tkn zfyGm{@ke5Yf#{IXc(x*=%S_Q|cuz&BwyVQ<^Z70>b!}~0YU& zcB|X7i3v>2WmbqmY2i-|6)vTAL8;QvIF%l`e0Xg)(3RC8r`g)eSI-3K)F_W zzVE}s!-Fo!G*H`i=} z#YK0CNJ!Kkrf{!q#>nN~VmMAgkrW@_sC<*1oxS+g=pNu*o44wU)Ne~3ip!N%RVFskW!B?ln-Y5W?%gvn5$86x z=;W=ct@T+~c(gbsm+p(XcA1wqVp;v6jg7hO5w!y149${6*^5RU#Nid~J!LMxnwkPh zn+q(o_w9SmYd6!l;1MPxtSU@%QKav=n-g ziAm!2ZCSz%_mDe&!@|O( z$nP5%G?S`cJl6H=SIfOGU0va2>dMNAy$P?nN}cSOc0KFv?*46AgpVdCUovMLQoG5! zaE6;j0#_x~xA{9rCbzxO(3)KMK~Hg;^`ApaDKBe>rEdOCq!Dngxjgw(xnqjKWo>al zSxrsNpYn>Z3;EJ#|B#Rn4v_~K7|+xz&0i8M7e=J0=xc{9J$1sPnnj|7W8Y_HChdr0 z;S72CQfPf{Zgqpxv^_o{ApzAjA%XLkPZt*CWDnD^{dp;Yq<_Ym&Z>oBhc>5KBQZ|i zTgaR&R-Dds`EuR&?{Qf+_V$I(pNq-M10!}VO?E^?u>$qKby`{^4WrJNA#O4;e%G2r z+~4)9(4aoiJRuzn{mr#U^xQ%dSK<_YPeh-0adF957s*M;z>%8zF6j9C=boBEo4z8V zqIs{HzkH$R=PxbQcVj+jx+yPbG2rT(p;7dwO+~vdnqU3mu|M{Fz{WU)Z;Ogn5_lC2 ze~N2*$|x5H$7`o9TZO11yZLi$2P^U}rUs;#wHU^@WI|o)jRkX4Lv#Zr9#} z4{dD|lb$t7}a9GY04h;M$ui^P+qI z6QK{dTPoVo(NVXJ`RWNN-=jy5Mx7E>R8diplcNjzX8l9{=FOW94m_7FauPyrpX;_y zw$1x(6TNoQ+uPfCjqUK^DD(7E{aNR#v7le!Tx3>+n^BAEqJ4+Ov~|Dl%rG zmP~i$5q395Mw36kzqykFO3%*9EU03&sjEvIU0`zCD0+gYNb3)-!vLfTs*Q;S^otK)sd}?2TDkxX5#WZa9KRaR+GIsMB30=3| znaHT9h!ly z4d)Dn{oHbAs_X0hBZe}A8Gao;Ktj^FNvrl`2oyghC1osoEoZpUgyg^h)9ZSB!SO}D zM7F0U>L13q!&YW!Fq-DT5`jD+9?;Hu=z%(OH<>oHPI%%oyTg3Qvm2hUKF zcp&=mJxmM&<`X|bynCV{RN&Fz?t(kQ!aFJ4J}x#(EJ}0t5D-x}vt_sQw+X+{ZnHdh zQg-$ntO2jdimtG`t6-jAoWyV>3t=Re=R2Q#(N>UQ*IlM9b6D7+6DF=~-l zP^jQf``XYT<8m=JIe9!z)MqSPJYfykv`g+@&wWF~mYJc06(8rm9GKXQK7IOhk?Xe! z^EqsIXAha+Yx-Q*&6h6RZs7zvjp8y7U&~O-Jdq%uytPqNt(?!`Ps9IxQ%dc_y&a)e zz$y;I$7yM4XU-RVtBc?^9{#(&&U@|J`&mCw7;U%ht<4Le=1_EcYz)PGv`QT}F(WRQ zI!*n&LUpRI@X=WTfoOs3RPAzC!dCuh$WuOfac-_Z{2WAuSQ`2q<`S3nWu;W*qryMx z>N?U?g8~ER=LRaE<2*=A@9x%MQ3PpsnCsu+m;dbP(|oheSK;AZC3Y4?8W}VvQX7og zwLs3EA3hx?L%lpduUhW9iT!)`&L=T3v>p9+6RnpXJpBMf0ElG_q~}g*YB~v_hMN1q z>(EeM9-c?1_ok$#5~$K)!{my?r>+6g-!g2bBRy8g+P@m z0E_~mpkQwYed#J9$Ys?5w4)~P3Jc#vB&>T-Xw5AzVf0?%v(IZ|*Xm z@D?#`Pb&i4vb4;8^5jNenM+Z~62Q%co*hscgv5GI(z73Y*Y-#_YAtFIb0LSBnubQ2 z7`&osf;TWM>?ApP%c9&-^?Wmj>Aa4`#YGYNqfxFtLE+)Ww$o}E=hIL&P-B>s)82l5 z9vFC({{&q;*oJ^y*x0gHp7PYM0!UtOZpmE@^L4V^o#(dgIy3&c6+**#vOCt+95D?Y zI2gZRECqmpFK|z#9T!iqtPAHnaDw?QcPG@ghYvqEK84Qs<_+aiCke3YznWt+Nw9}- zY(PK&i{J2mk3rh)ztw~#lhbWxW>!c|-0Dq!!F6>wbU8#t4C+97N=jVA>yU5vZ$3Y) z#yFD+W{LJ8^ij@SSU5IpA*(zQ&jt( zRxf-=rq!?bO|z zAZNf$t#VhOt8o?DR921h{#NO4cXvnSNEX!b#{&M9Yuo}f1a)9~dfK=)OFZE_^pYPxxDR?OJdY01YrSyg=rS`GIS& zIk{cW<_mc2?fng4NIl%(4S@^j>AjP^_-A1urrTUwAO=tKEGe<2m6Z(-Y#xjS7boXs za9W6OaxK;SJ%k~qhUlYCSYPs)(Kn<_#9#v99zTBE`azJbz9jRBU%!4qOkH1@wR;p` zF6@QxD=Km@Hcs{@H=48mh#7!=e$p}+qzAQxn}g%n@#Agn?IxwED01fJ=}Ac{ii)iq zHLg<_`??=LaAn3aV{NG_E!$z^p4&AYSiFnY!7YNFJg736T;EuPWr29P}v?a0^n)eiQa&vRDvdS<76%`ddVqc5m zv0~uiDRNpM!EUYNyp&`>M3j3cK|Ucq-rm9CBsKMdJs&waIp`AZx@ExA{bg%wt5gJ+ z&fn8SM0|djDP)bF;jr!5duW*tfBV{EMP{A4|uQ^Vr$3VDqE7@WTi{`Cq=%#&#kcO=>+deIBB8UJ=^p<+RJPQv z-XO3(1x>5?&s+W%*ij(Ukn?|qq+R8II^xZn_p%phIs{I7dV6`j;dfe=U@iSUHMKDs z<2I4q2GxP`^l2AVgBLGeB(lc5WO+Z3*xl1Z!J=Vw@7_afi!1`x)b6@*(Aew;w=Nvm zL)7?mV#35zXLpaK8L%AiY2aBgNbKaoCMG6VuU;+N-dZP^S`x0sk3m6_l1AY?75^=K z@ZbUDG)?EmXubMK9&99

ex1W4h(QpPxN@CNy_VP>K}@4Hx(|+sx5%4eQ>z`#a?F zmX?-JpWb$c$_RI~K2)fxs)~+|Mn%$eXyhX4j&7pP(g>W3qBmBt8y(JCjbT1UBotA6 zK_nVC`$4AyJggp669Oq#8jMtEOek&~moHD>p>_Gsk3dFHYI~wW+pVPz?c1{lRAqC%n)ZTBv~%3P ziF3+w)r=;cYvI*}GG%)2*^I>X*0QQ5#b}2fsM8;G4Yr`|CN0pi*=EpOIJX~zxnOhC z>(ue%oMt~ippIZX@O?*x^B?q3k&$HstL!;&;^XQ_1ShA=a!DcZFSat7T;oH<9`D}0 z!-lXH4wAD8<3{7_M(#`T1Tt?JMDL$hSI4KC{KQdNd7hJV$(tr&;Wf{EULT)W%~0B* zTg*dC%lH1U@>jX!+g=Mhlv%Lv-Gza6K5yY2M2pUv{2C3dt%Ty=*4C!sEJ>`Rx;$9v zGe19%=Lh0yodAd~$v{v|h3)L@(i^KA8yi6(Yiiyn?D}^9D#|oUh(Jt3ItLru%WwA~ z=UN}R8YX%1Sc*j@o392-K~9cBU&?W30d1DhlA@x#S3z7$<1Jz=p{B#}#u=Clx6hey zJ-N$li?dr>n_apQVRYV9DfGHP%q(V}zCKRUu72y*E&1fvED>R0OVhtpChtS*3zNLH zz;JB8riKR8CjJALhuxPy87t@H(ZLG4dNQE;tb5IUT*s{iV=&cV@rcWc5^XFeeDJd# z7!~%8j$J>0Hu4vE>rg_K$9F^6c#r^jRFZ+*r&8?_KBrSfE^lrq4@y7zD?86e0_2}3ip?y2Xup2(t~y*u05Ek0`PTMdudtP`>f zI(Pd%Np}2Ls}rxxXUe*?=o{^Rg-!pjqdRGw#aEoZI8EsyCm*|5bkTtUQ#UV4{#`iD z+-{ureofB&aM&W9q?OGtv(Gz-im7X}HG5u6O-)G>^8mC{&{~m^k!fgZ77aXOQpvD> z_|VJSTea|E?Co>co&V0&NAZ%-VcjbvNHd<(Y6O1h;|#aS^a5f!5Ytlod3rkc>Y-A=D;*sj4wlUH^qzckbr1xJ zJx9Goz_^Kz97&qrU7P*r(IYJ6HkUEFn1-0^HYT@k6HhaxL8Lg%n{V2|vJ@yoA83Ty zg9Q(@*!?B&4=M}|pZ!bITnOTBn@b%%-Q7mHit6epV7(|&<8Sr}d-?g*`JQA)O*FnC zE+K&}+;+U#Eg}t8+J7wAiZUQRK$+Bl?RSPIO~BCQQEI<}Joxl&4=E9idfam6esKxC$1x%oDiUGc;>O4It?= zxWn@*O-V6nPg5Nj7)Vu47l+K6)rA6Z_Uu^@M0WeR{+BOF)v|bac)&c)pFf}88lRdv z38&3i9Hre|v>db@@B+TtgR?(5}yg`~3M~(B$^625bq^c7sz> zAN6Zb2-iH_bKr%Kj|~rnxk>3?m}Ltib?a~hmZyJ-wLtd*+Zi1hDO;%L?UoL?tUfX@ za0MF@49;|_2MfAyA(3tVsby_piY+d71d0S9d7qq2m>bHT&kvrwYTC|R%xCc94dk;^ zg09ZL^3AOeRj|RW7<4BgH}dU$;=_36%)Kh#lieMna&k=sIqKD`xOJLVl5VlFLoNz| zA8rN~29bU#IC8>o>vrHGzZ|3crq+4q9dzz4i*MI;Z zlV+8V8@5kJZ}+vBPoH%hi1FjrrIt=xogH|pebiort*VrS&i#Ym$#SN8m%4V9@k8oA z>`n@|c-=~;cWZ2!tCOL)qyjuNApmTH(wVBk<{LD00cqTAK? z?AcQeCmkqrycRn!aOTUGFEMXXq!W{p;DM>Bs-~;ueQ#`>!*lfy2;hcK_2yBIi&~I^3k%y<}7PdK5eh_y_85y|B*ru;P?EeJ1e3hM@mful? zI84CB@&5ha!AvT)v%Q*HS~LCqq%<^I1i&2AT21SAij?%_>(`3_vIh6>qvpD~xdD>Y zhO%-;Xt9K1%hW7)jq+(5L4DBu>V+vj)0s00zK2zb366rHp;13H-|M8K*#E=p zf)yDF`pfjx)L~kI4Y+|yDk>lkN2oY1^7BV|KGV&z2fL%8p@APF@UUOVz##G2vkF8L zlo*fK?XpYsJb3a|OpF^`Or#7zR{;p1Z{V8Agz+T4eED(;V$!KoFirvh(D(1&>Ez@@ zOB;hsMS}btViYLIk7WQsAgfTe#fcx;*qpw4zXn3Ladkwy>AT$AMUYG28EliJ;>gp`oFO^jm7X zP#%qqjqx?55S5UPC@(LMj@E+IiNXkxJTGqvtiI4qwt;^o5*a4Y=`L%}Adt4HDKl;D zKA4gaUm4Dy*N5}Z#Wi8clKcHv?;v;(%AaTT&?6l?dwY9X*_VE^{?=bHbxTA|qAyqL$xMl@8D{_8yFZ6(L&r~@ zF`v0{py5mKEM*F#q|)(3GGTiLm)#jR9Y$C_YDxQzSHI?#CM|p1QR{6%0YjWr?h*;z z&Y01qz4eAnvs_xMq7Dnk21r(2p6)*U8K3Sf7SI39psrHMZYMYLm0N4;j_S)F%Ar?x z^6T{9zh(BNU|z1LG~?sxOvc-EeoMmJBxX`!j-mR|t>XOx@=D?9KUOE3kMF&8;HKe+ zBY%#O@V+kc_I*Oq7c1D%ts3U=BHiUJqon><2H!xB!9NY(7vH7CO~%LGBGiUb9Sf3= zRVgzG+zeB)y<0(J)CV-4&pWOt^A zT-U-O#k8k->LgTlP`EoCE21F;R~?+}bPIOB7hLgjgcZlZp;6|{V_xa)mbOm<4Gwz(ct8qGDp&g%58)XBz&Fn(_HF9Y4PdG(AJ)A4ORNxNbW#e!pxZp?$~6WwhumRo zWJG!G;oGvZEld>GtyM$e+wiJfUAG8wl}sT^8!u1QnP%Zb9n={* zIy#67?xdbN@cWPL7Zny3ZrRlMXy5eZ%V4mIR&rXJq6J<1rbcdj-Jl@0epFH5NR*G< zsn4H3=jG+;W)(tQ#^yjx_qDA}=?tkxk!3>gt~|>tULtYt-(L*T7yR45Yi*@}al8dO z6cCh=tqSfK2Oo_Eo)UBgHK~#P9&zjxO3*=t&isUDHBj-KgbuXueMW{Wd~QUAL>Agd zXso|5PA*?IVA@Yi=hty~vV7xoL;R8Qg6G+1JV@yux`mK7KOAJg&X;lJtg>74t!NMG zzA6=mQC&UJCyxin8{|^5BddJx`t|(zyHu^hWU;A4=f*jf|H5#9;S`&K&9*^L8Mm>I z#ZX?wT2tTL`tJGZjo?)kMGs@X5}{}%ZKdAx`Zp|dc#qoqliB;rk9_Ue&VTWNnJ)gy zLcGxr!N1F&okQs~T#4z-L|uC_EnZbQ3f~^`BOe?upF6!qPP3i4y*K@YMn~};B75E9 zV%DqE3h$jP-0|eBYEfX7m6d_0ChU0dpu~vh5nf~!T>Tw{2ff6be&^1ef24C||3NrP z#H+|p6>&mk!bttAa+-y2A-@LI5P)s4{dI6K%Bqh@T!^EVN$g1|G29eCE(MEIjl=bQ znlZn)c+P#QuRP{gN_IBEeTF7I*_KK{M-atv4r`{K1GPGoRV^(!khH6+Td{WPV+6I4 z$>F8^ySApRs0b7lDKQ750JxRJ;|r!~>a01=)NN^Ld9^mXyUvsm#)Ico4Ju*cj-r;n zKJ8dpIEu6yPX)I(6p|G!4fE@|8a-Irq)!78M5(b|7HFQp>SiGuhn)xD zd3|Y8j2xoWQDLY27Iqg{8P&EYC?=4w>~mg_>bjR!OJg1T z(ILF5`t^h~c)*po%~Qmw_VHV-D`JWXiW{q!=|uWDZ}BG#{$b@1U~vDysh{$S`O93P zK=o;sYeYgFV-w2vscdtva!C0&TD|r0IY5Wl0HQ8_>UT=mQ?_g7S2qy)q@W0exidaF z+2Wd-ex$(i(IXwzcEI?(9LM~QUS?))b$iwZq2SP|YmJE19!8)s+hjrL6aWNp4t$)Y zU<|MF#bX+eC#R+~6ciM+v@)VSC(nH!hea$T^cd_uHkPg~$KT;Yb~c$$rAEVv7y6w@ zbG~9xzkBy~LNh{?>HBUr@{XuIDO?LnNl({TRi0!v&6{d2q2%y`_5yz#J_l02Z7oQf zdim2RTs(HY{nn&)^0Glw$waYsfjg<=F{%uG1A`M36i$wg;t9=AOAyr*_VT~HsaR4k zDq($T3(x0WT%5DBGoZr8#)iAF4pJRLXY7%&f(;=hB?Vi=)WoDvi@eTcSobY$TlErA zl(d#2lC-C>sQ`Ltu#`KmTI%ZRBCT@cv-Ut5i&Nv32rAZzRN&#cj+!r3ze}A{%i_FN zg6&ust;mCU=BM86GfQU)a5P>yM(%vTPS`7DBY#bNILc?*ByV_W`JCTUioD-RcBS*G zBFX*VrqnuLrl*4V?&J@jl|FWg=q;1AjhFLmFF(|X(-CTUKRY{*3WGBmrKR~(r|9eG z%p=V8CL)63*s)_LP9QOq*ExOT;<4s2Tln=-edf0HJA`YHI72|Sqp|TNMiM0C=g*o$ z(k4N&7pK}%I{-!V+H;uIH8st3b-%D3jM@{IFc$f%AzQx|Lq%YfFtQN-I?XQBwD%9h z)%@3wJ*A{*f1uKf5?|ruH18NtYf6%#ErxPZlQ;_{T181IPnFudIiZVo=(j#c%1A~G z;>Lc+?k(qyfh|G{l;~-XXj`?4)tg5M5s|#%>1ioqlEa5H($b*!F~l?k`1&3^cu?~8 zZ6pq);oB!v+?gE9Jw7pH8uOBi71br1fR#0i3CZUpBqW5b61RC^|9;oM1N$6JtUGUx zF2E_dnI+AyZF)FXEDU|Qja|~r`~qW4W;Lbi`y)l`M?%t_ z2-K#6v`0WhQ}rRy3Dd;mJgV{sM?3AJ^J7oSDsN*92QTy zYK~__&}V4%4Go#W%$JVsUyJ_?qA_uIKe#D`EKt>dJ`=8oOylLs7O4P`m5M}f2fk=U zMy#>>s`+51DJkC(4DaNZCBGoOI6u$xu%C$g%H@D+>Z|v?{1e^^&ZVDvYJ1Wm{IRs> zC(e8unnN=|uU-)hvykT8f5)3Q$L;c>Q1^BS>Y6{Tpa2I4h(#(~4eb8Yj%qfe4IciL zhN^HOIGLFx8H0gawlRRrwIDc%Q!h_;WCA|d`PZbsTiJ1){Z5>}bfb*vVd5d zgdjNx7uaqHizXF0c@R|DA)4zlN|vbuU#Y&3&!w0Ct$(&hOK(&&LJdl^#pIb5%*Rt!CknZ64<*B@}PyCE(kVTR$- zO4KJ#j9^-*SB7KvzD*okdM+XX(U-wN%as|;C7xRAFy8zRA7tqR8KqH7kE22QTNc{!_frM!q+IX@vNBt!z|4bhVP|zvoy8Y zNEoWA`_b^P!Fhd40__X-t@6HIU6=n}5oiP0@X}W`G#piq9_ThO9}^BWnPp2iv9`3N zcYm0kzsPe2I5W4yqQp`4WwkYnHS-p3gVHk$$#W6<*wSCBAyVSva9(iesfO6DPE%fY zIvuqU*b+t(tsYk^$_XR|Skm$iLIi;$!jHU1X#h5IeI_=DGE1_5ldtc}Yf@oFzwb|9 zd$MbD^?3D#r#A^nihX*kPRaiX<= zzW?C$;Z8R-J>zoXTjhSdie?pO^YfoxVH*zH9Aza9V4aJ3AAKuqo>%VDo*hB0 z(YwBk9nxez=-@}KB4%W9?~%`Qh4>i=*86_MweXP(6Kl(4ZgU8G{aKDlH>t{sc{_e+9xV&tF^6gRe*X9|1 z`fT~i{Ji&z7nTnn-jfNb2wDn=%ztL-VNeA($dKJXneT$87_!S8S6c3*%|T(>-qg6zNR#=-5n# zq^6nz_S(&BL%jmAM1E>|LbwO@X$@t&&wj}7IFoXufJ-rMw53JC<&9a+zcJIj2PtKh zl!AkTZeFl5Q;48w191dO9j%L?ICUyAFi^+Bf|iQvdp0w-hbTIYnDw9053A^uUs@aL zV6qxiI=;u5iM6cuUVgr?!+K_zZ}6UPfF~sfkGJzH_LP6Cd)Te0S2g?P@twH$ZkO0~ zN$B|01=IKHwLNnd{jR0+>D;Eb_9ctRGY2c|PJW*}Z}IVwuu}MD$H2@v;*F{U&$Nf` zqz4z>NqME-rCKkL^wH(X>@}h7$|b*ot$UO-eKwR0-J*eu&#ta_>oNaDP|`q85Ay7+ zwI%6JBB5&Qn_e1na`Rhw6lhGtwwb$hJ6REdQJ3w_6?1>9a)jjoULb1ogQ{Z*W0qsP z1|g(lXAffd`O(1&1Q*KKS$cX!MkSc}0fJkU;Pa?yhpM1S3py>E85xbizEhNw3k(XX z$ey|o<4gu~%cS-FC51#%VSr12i17e51~}Zj3?4PI z-A4`_cvqyBrCSM91iJ#Kd(hLoWp0|1mNqqA(1*ms%uBkPMlQ%Rp?4u{T`DCx*hjg<5gp-JR32p%SCte!ckFnom} zd=VNmMn~;XjSXmL8V&O7Br#-vhHJImHf_45F?ZQ&ZQm zdf@S{qfY?Kcbt;Gz(Ee(RKw4X2)uzt3poeA90?t0DK3ChCBb|qcHtOBxqj*kj~zj~ z%UBze8ZRxdGS$1)DQ}Nj1jH#nn&-XoNkpeSl-lBJJ{dg|ncY=g@+Z~APqvXPZ!7g( z8`>rnO>4YOULKZykq59Cv%5ps^Mx@%CdxyZw;+Cw(?n>%Hm72b8z3d&eygz8JUm1ZNn&y`BJa5aiBv+bhSD=DH-*bHpONYs7_`GU9f>3{S9YrQs?>#a4am$*JqIL8Fl!>XGudzMqOC2m7p}CRu(Dc! zj|6zXwz`VIjuTFgef7b4mk1SOiLQ_D5xh%C7{=pPPkG6;(&*qDy+iM(gpy6n#c=9@ zPvo!=I#YDx%u=Uz!h*=I+dhOlb80{g%Zwz1PH4nmX;I*#M)D5@lSR){$f zJK^7PD6k~<-JVZTtYlyzjF)$W)=xJ(Gzv!(#Jwi}L)y;s9xN3Xw+qrCJi+Wdd&?K9)>~7s zjWU?mGt*0mg1kL^`Z7^;A+|YBoma13UFQp2-lqAdx6R1ROan+YX>4ZZ6h27$wCvv& zW;G+P?zVu#|HMef2L79V=KnTV@?Q1a({B+gRzNQ}YTHm%Sl?c_%v0dh6GcJz$={N(BD!oz?JJcf zk6C2qSzeAa6>|#`elo}1WOeg*a@i_b>?;-t_wV>qPXz5hzY{SQOP8HvT$dxg($>K@ zY2EiNKC$G_KYizJw9ZHg&D7sy?>$tmZ~XIZnWv6dX)YI6bzDhys}`(S_%7&qgie7g z7^8>1NX=_&25CXKsbzf+bd^kWOCjP0r43;V+Yh<9XcwC?%`7E!o+2t391@aWD9()5b!z)#LTOPG+l*4B$KLl_y0A1q%+h#AwI=h`(?ChN`{+sXvJ zy}?YB#!P!eU#`u%G*qN*H4o zHwsUpUt5iV#ed!pWFX9T^zckP3fJOcu2oi15p-M} z^X8fdZznVrEOVK_LqK5s4%4)hT6Rtja+j!Vx$P^!m-KeRn? zwD|GD`SY+|ZSAxi-y38`)^p5E0Be}R>z|;FgiKsgSUA=juLhfTec9a5&=6^Dn7qxw z_*%GV?xbjZrwOQb=|ysZm=4Lr#RTOY+uUWtZ^gz{(`!z=TJw&!QFELMKS!45a12i`=v(fCgI5lo4XuS~ip`Tn6@kN* zKYxJS8xh3%?u|Umevb(zP!vL}W~IZM>+024-|qj@T`@EB01j&K&wukk!S8AbiJK}m zDjCbz=k4E*#54f&=VfO%i~nHrm1JnS{zB*sB0_bpSrg;qk1Vk0t&Za*0Qy(M%E|=N z$&Va4;;Azxo=%mh&RG}zUm=;}DuR482+<(p$RZ(SdNY*cpGE^T4w|%mxDJmNO@Ep! z)WXCSY>@C#Vt^ggEb?4?)l>M>qkeUEg6B|EQRTmW0<*DZ*fRayLoo{C-3KeKUAqQL zpVPulvf4;wI?P;>6>6spjr53NXW>sMxtLfLa<1mPJ1 z$Ukq{0QD#=Dr%BF!^vp_{*Lk?wJ-xM^_BANbepeqH$5M6lUBnhN2G<&TKX&h zKbapM*NsBJPrST^?l;^!lKuPD&_x3pi=3p>gCj##em*Z=bfFD|(7zP^`0-<$;_lrI z&NUc4V}$1AJ9o0Ov(*Y7oYov490WfLf%sJELymcTU1G(eZJF&+cQD zQ^g_i4a_Nks^)T`I9{$}kvIL)6CpNuvuH1Goc)IGaJ+avS4?vuB5jMx$|&>qvKvQ9*qY zu0f5&O9=)iCu{R$nquzFBt2)kXDF~Ph0DdfbASH4r4jQ63lR;}?H!X{Tpc4F491i3 zLCgCn=guv@J$d;fbBclKvhdbG%Ji$CXhDs<^9Sd%oNLnUUbZD3lbG$`ixyPs?d^p$ zg$@wJU$9a@W?j*)v@!!`fnG%bn(x3se0+!zyorgCA-=-S?&Iw}4o?O6Y#FW58X9JX zh7@FE@`{Qg<(nSD16UsDXM->=P5j8t4qXQhAZ22-E+Q*J|4kUFe)@*bUFNN7!L5P* z{)ezQKy%=6p-T`i1B%bie&u=;IizpHf3R|_K9kekhwTliXbtEGF=k>q$lR~L-G>#A zAfObYeljx9(&Fi&Vq*s|pH9)z0&-3b4P`zm8bW>1z-ByYfuo!vfzR>*Me5kA7bO4}>1@Utl$0#)cn16pu)aU6K|T z$E#`%(}}Qz0!Ba8d4uIoBk1}uKfk@A;w&@s6hsN|Y1|!}8}NpX%Z!X&c+CW?zM@}l zih6o_62?YLp3k2{S*PK#M5ed~I%r#49`wHT5U9(VAMV{yRqaGYK;#?f3H%bkucGw^!Ts zY0`7g=alzpJF|)7h)woT2rau@IJdg>reLMyekne^U{Cs;_6ccI_|fHMk0_4ZO?Yk*)2_*w|xgJ7&$|R%Cqy1ZY~0uG&URG7x%` zKUiVwKo9}JJiHfZd1dA1vrF2a(Osjz6p7*jty{;!VimFO|~>T4xs@P^jfq> zVd|oa_PMNRl6fzpfZ<_$``*;n7KI!yrVrL6PITuENt84<{1vqEfhpi!8Qx!43i9*o z!#O1s6v&sE$h?o8KCP*#>4LZpn&dzqggaDk-qgd>#c(23Gj}WP@v|nEJIgaYlQT2X zp`jG$|4~!>evb`}W;oJlP#aBkbOOV}0lAN4UgU__PpujI>`A7B2dA&>4oaQv{2%l| zx9|ClN(_#7*q3e06)%u_9+ppTjgLQvw|JoI7w>jDSR#Pebj8Zrqj?(BW~hbmDlr8G zb{3YIjnxIB$B;hZ5kcm@*>Y+5As{_=)6d4f6s%gGe2}om;)Z00mek8c{h$CKkpk9Z zvWkju8bj{qCP2DD?-WKXmcS3OijjJby7ovHZ@-fFD8;~_jD8RV_f2{CE(cWb8tW@i zWu1z>(C^XzaAyf^U>dnbf#^{}EjuV?gb7aO4Qqyk94{z&TB(acXS{Ods8$i)SA*0Q zfKlnlcB!8jkWcU zNFH{)yakXN278d)6&>MORSP!o`LOfI?zwDK^ z6DAAVU!gTW)m>R!l#J%XYiLqGemq4)im*Bv8Q~RS{hxOHY(IJOB%}ZYTo4GKW2u2` z_krwNrq16Gde1MuI0+7j(brB%!}>@0^vsXfy~rdea`D(qSKa~JN~i&-;$aaHIh|93 zgStXOtoLmgpH854`LPSnx0;%3cxHF+-bEO|$%)rnq^sB_H6)}RQR#aQ*pz7nod3d# zfD>m2PjYv8>^SvhQb_nA`fxBk;oA`gQ_om@k_wC#7X$T2X zh(<|BNQ078LXsv#ND8S$gCeC=hD>Q9q9h3k6&f_DOqEDQ6B-Q7GmR>-KIe7a_y2jG zZLRlfy;!}tx9z?X&fo9*J%;_*_x%v4X=+}-d2{)*qChea)hz|iaqb+n`Xbk~lvHLQ zk&%*H9*(9rr7A+&L1~gS`=(ji_+CTw^k(f_cdD?E!sO&XwEKP?n!c2GXu746`|X

d>_*{AFiJTU ze(lycLGVpargZxf6Q57cDm8pgQ5yG?90J;41#h;bL}X#xmoK6Mw~YiGL~)|8#nhMv zB1uF*E;P)VXCEFl>!iHMLyaG@0|dTQdb0h-77sg8;`i&zk`L?M-O)>wep&19pJ69V z=5-Z_CE;2RQ`}BWf05oE)G4g{^;{(WE{-~W=%^^sc*C4NTU3jW=($Ut=S}_?IYmvr zc+!Z%zcV-iTXY=@BchjOjQTP|C&GK8#^}+#8LI4V54IAUXU^QgIU#a89(NktD;j=eXt=(%uA0ojKTzWe#Ni`T z2ZEFc!2$;{ogp`j$r7T1re(c1Ox5U>nU&S^f=pcTLh*jLtE&-Hrdvz*=yB5%)MoUL zMpEhUKfam%T}WA>Fc#zh-I$Pl#eV%)tP~bLBc7DDA>4EtJg&5sl&b+iojq?b5n>kX)XTFgvmYRL@>jXaDE3n0-OYx?$){reyzj0I{ zjn8-Dgsfj@ykVGpOYL^U?Wd$Mqgn0cRd@fsqJl!m)n^YLbW2R1;%GE{cvWNLR_?JN zn8@1FZLX6kcra|iqenG6zUD2spnkR~Y}&4Mp@wR&C){kcoj<>?vhww2U&`n}lZAv# z`<*cEok&4Xd%NUjjry`~kcYRo@v0|ARHx@PYx#NdV=q`7S1qx0eB-7Tk$kRn>cXZ4 z(;wOVUZ5SJxi!S^Wt(W#;8>Ae#!zt3K1#{TUS9WJdBHJ9GW#OqYx@Sz&Aw0aqT2xl zOiWCy+U`BdzyO0Y^;m7@)0i7LjnYYwKD&x)B5u9=`(15P;-S2Ah#EBaTZN6+vfC?G zuB1OGmu$~1CRK@WPWs6Wl)feq;5bUyEOqL^UEq?6EPmP`~+ZgY|<}w^4mHBZ?xA9q*%aOiTRa#Ng#_pW?dS*!iJYszFr@ftzo$o~Hd;;1{cG*`55){$&_7vWdy#Ww!rUx^p z&^`?4PHuWb`yN&@GicJ)^z<+bbv@@(+s#JoBR(Q2Ruoqrg8O zhA#!@#MB~)xm)ujz5eqrqN`Rom_y??ievi2v!0PKA)KJLW1bA7f{JC?}k z4A&b-;5~Zh-(mL)4fOIc%6h$|!#zh~HQ}giXkyH%pU$d|2|` z=~K-N5sMs4*LrYso2P}>){DM zAf0eVfJkDkG}=@AooovWo3EVUMYa!3l|_A*-ta~AV0CEE66Ky&)hSkeqxvUXS{Q&dr`O~!AG@SlEuhbo z<(bc{7Kr5L=1x_QC8>4zIlyyW$v6u6RYJncw{KMy6}R27-6H&wyQGS3cXl-Tt_3R<_o$V#-tQ8;t>V2d)U7 zA8orWmTtC_si4g-StQa=C4Vd2b6QN&!=3%3C1b3@Tn_1xnkIWoh>KfTT3S0uU~>v7 z#w~e4snB4dy{5N7khu8ykZ07uiidjOTruC(_2cvBVT=Bnr`Ir7>e2IToHf~)Zr92x zs<3eEmXMWzOm=p?Qf3pX#s<_>RCp0Oxh1~Bx47QLrC(S{aey3dZa|cKMl9n$>>1lr zOX9dVgN@^I9Jsb2%7S9NsEm#Kb`>ZUYA)SDM&4$Y)=?Cd5vzfb*ww3()zw9>eksGR z0-F=N-LpbMOLl%MC^xVe_;$7 zfCxH%>?|EP%yrXx_KFtQQ&~FZKZ!8yxagN(9dt<@pGxx;5jzGCx+@UqDi2pzKZ14z zU=e-zhfkjt7OqeYh>tE6Efx2Dy|jIiamp1WAY}Qv{n6UzylW{p;>BGX87U)Z@oz zX({n!4TlHugZOT?;W(TY^DsSJu z-t`{+5Y0Qi3p&Q|O^dcgMAUkn;XAnZt3R5N!KhIZMzMK3&fcjl>GQ&ZB@p}FTG=@`tb?9Iz^6ZNq~%LXE8(#{dUWpcWmkMTPdzr-{QC@p z4dlS^luf5h3E|TrF*p-fSJ_kj{n6pE}_C!Wle5#biH?vc76&&ZRUj139p5-x*>ARA-C$?jxNng#G8okg-2GRWx zY3Ep5=NvmG9Hj%%L0!QMv)gK|I7^x7$(n}ytu(+F9r=U=>L<0Au+jSc{eA1z$^%mp zGQ+iVg9X7PnFrUdErIN$tDMVwo;EG=#&a}8NS6kKv74oVKhv8Ed0RYO{!nfU&3tkk za^22{mx}T>W*#&eX=J4Gsrr`ZDELof#*6_N?6&^C=8>kH{X<65qH!TGm%#?}3X@Jk z5mUg9!2K55y1Ftwd!|lUufO-Eu8xk*h!Ic}BUjIbd!6_-uKUwPqxS16x5193wU#oR z|8Rr2YOkl@p$Ea!&_Hm3V0_O3+>?})3^6^7n>5dMk>KAbo3PG6zO3A9;D9gsSxZH8 zr_2t$zUq7RxeJd{M{mA<#%u19ftxe3Y8p!%hp)6587h5fxUIeYwvVG={|n0j{7^T1 z>0kc%aq=|}_i2%vVm~+kmkWTY*3Tb57)=$A*5MW*uVZ36C+i+jnj)#ZuwGJtuSbt? zv-3TJHf}V2vcd|D2EL;kH{y&)!Gl-Jpe^lFR@6@rO1eSgggBh9Z|j#Yp^eg1!UeN} zHg4L4wNKimxA5FNJr%*D8yhj|(hTsgAa1ouc7x*k^7U&Mtj7^>1_nC2xZv0t#I^r& z?h$6;6vmr7+6$7#Gw}hb;qZgrbKqBuy>$He%oW-90NN;8DeFFb{8+zYtVw{-XYF0h z;5`5*ot<08CRaaigAos^(Oe=9GsC0m9j5D~B4^!wbg+!S2$`!^okXx6;BU2V-FMOz zb#uzWg8*Sr^ZNGf3oS4rG5_q@2bGoEZ#;K)a;lIPOPD$2>_#N+pqlw}V_|%;^}uZi zNIG_9@%#4Cs}6Gc88%H#o}#80c>3W(<$+s-PF_H9FlkSo)Ia~LPi#Qq@rmcNw~%YL ztRq6$T3KN<*)2emZfc*Y8va#$Qs0jifN^bqEKAbO={Va&4qJFT$%LzhT=h2q16gY* z6X88nX@aqZ^RLVR2EXrtkYrXW#bspDkx9sOuofvSkLAt&^d`HMb#px*TBw>}`CrA($(e&h$8 zTFK&9GTlxOuhrUyO6ZF%+%(&tkB6nG)^Q{&if%w{8JT%sM{RI1qJ$+olPoFFw zsNQ(s=IU>0YYpUe1dqCQ9Ed4;mKBt`9^&2AQZ*H8Eu(tNNed&sKy;4V39atXLTQ*I=sbb4QaIgUEfaIg}>xS@~ zE((r_xDM9DfIe^G<=^CCtA2)S?y+~9%6WC#N;b4 zjmqy#l_|%PoO~__27^)A83{rYb#(&w2?Gij%;N0r6$|f~Sz2C4XE*9%=wSKuYue)T zcbr@Cq_6fC-#0dG4_pHphMP%d53vu!+GfLs`1tr8mE$w{twL1hd5!6tT&*k}4#mJD zQe%jGm%OExvU7@h$8yWq?mORn(Or`L{=QReQ-hYu-yyd0!hggsuyq1oRX_8bR`wI8 zH5W^kw4It-soDCUzmyVJuxyyRrT)my4#$7LoaZJZZNF^OIF(AWtA)9~|#z z=_msMTR(9|*1c(tGNCk&z84lqhmUFqRqLd{RD|8*<1=AYeXLamqhn^YGiS{3@m$;A z%--g~`e%WWRV<=-t1xpiTnqigTpGc+;GFc6C$E5pP|ZUbeD!K6b_T#sLDOk$Fy|~N zXp`y@VX9`?<7|y`)8|?Agu8DD9aV_%%?ZNk(;*?>Ay#oQ2)n;AZg+iDTkBz<(%M^Pr{(jt7l4R7zjTctmhm6Q^p))#IC`l}WEmRg z*aQU#;HT@?YPgTu-(PN;qhp^-M^2omb~6C`5)se{%1DS+W}iGcJL?`7f9N0&H@D(u zUyZXdxF@Q}BuzkGb-JG1(t1-m?cQ6ycwY*WQxzFliTfB$PFy z{;c}_%sS16x&}+iq2VFLmIbFz@16VF$@0t96)nRDZ5{GZB+4o?H+Q)6IvI(uibw9* zXQ!ptsK`V(jSKE`ZBpKF-E`eqA@kMd`8xIU9OX>JE^#Zq%}jpS{yg2AFAwKGRh5aL z0`e49+GwPk0)WBCE^s13#>G`+V?m}T7s)b>p}ApGMl$M+x*=nlt3J<_dMHV>W$zRp zu^w`lLKs_yg^7Jw3<;jkBrxb-y=_^mo|%oqHG?r@poV=}qIu-0M;sy`jEaojl41L+`PNyyMxLnrP2f67%!1s3!MFtcb zv;&bo%5czdn4=}ZF*08q6&&+IgOMXcx-%}DlVbd!zFu{}fRCR)_q?~_$$15VFJx2d zD<(xk^9rs2y?LFv=A1(&XykMgkyya}aEr0sVDbz#3OcDaH>4Aj`dHlLM15iwyjM^JwRg z{xm**+#}CMJD$f!B?J0wId5e}_C#h86ol>c3-D#TcO0Tt_iJx;@4K^*E*w)Q^(FCn z!+;y5Wo7?D#gmqn&Ub+jcdOYKB6BukwO-+MWXquSUs$1AQ$cTxB2|Y1Gp3=?MtsN)R z(jbewy6_WANdNBjOQlsCRxM9DM%}08Nnt>9nLX_7}pfD%ao#jn!H%o~C5kBC)+EjBPOL8|A)k`n}`-bx|@G6&}^lk38- z3qA$f$sO+5b2m+Bc8I#a@n41!LSX}jILH=>jqW}^{S+1L?d(KPoQ4npG!B;>PaDP= z_^(!a6tIR)O;Nhw`jP@Cp|A$92zHN<-s$K#Zb_x0(Y6JbH!3w~%F6D<$1=X^C40fZ zJ1JxpozyTne-DUR`T}v`TXm;YY@8_P77VCsHiP64q&mwTpnUtegm^_f+qjj zMlk)wqlJvb5vLx6g<)*m31A{9CP4BVB5QRYH_yp3`~a(PXyrBosWDiXnS=ElD|HAS z>2ULD=+o1&+ytG$kn1F~HOdmBhNoOSDr_FsKe5{|0gzZoH>-cRYxC%J_Cn`p%!z@yV5win9s3tA40C2@8pzm^r*+A zMke*7jd?aU!ZAOiP8t&x6{*Me`$XcmCMMPKm;p#QxcnP_3#R?B;e@r7hzQpM;Lkwa zxwf$C4#m?(qy`$j-tuOd#379e@bzW)#`&#XSB_U3E6OQL-wT!?1t>FqR2Goo1h{Tp>>$tO>DcMu0Ti7Y5yfC$YeKVzi{mJEOg)k_&&mTF-)LPR6v_FJ@?cfh0LE#ekL+1Ot$vaA^s!9Q)sw2rhz7h zu3p{B=B*#c_#Kf!_tr6GD+#|mKm_Ix53l0HDmTDbd^K9(PjBC1UU3Qvm&}=EqxN%9 z;FUZnSS%S-K{)`PG}gUz8=`57EVEG6n=H?3Nx&d{?S&M zqS8Ji{_ljLZJ+0hR~@sN`y$4X1MKQLOLMQenHl*ZeD&SItA9eZ>wzeHOS{NyQ%pzUw`PK>(#{J#&j9LnVtn6&nK79!7fdW&M8LWIEH>62Q2y7f3cSl9h z(?b5@AqxZ?Dom43KuD{2bca|DvyOdC=2|zeU$Ok>? zM6^cFi1hufGSg$^Z5?g~vGeW|g9Zv2uNWwoz~epmxDU?s6Tlc!QC6PkW|T@uOU?JTQ8mwLzKG*c2f!kRj8*-n z1?Afq{czur5g32(j)(vq%RkgEp^I{B!gF4a(3=bvFxKqkcqB-LaP(lBV7F;x3BdzO zSSzE>MHCiAeJ15xz1o$VcV#1_4S~?)%Vq?aDfdEVHMD_Od@Oh!=NIUI=-5L8gAajR zKpa_)W_7Jme`qG*i+n@`1Q)w$LE*c^V=Jc=o!#NDJ1Vh??^f*(sf3gOO3Nb>2z)D> zDJzl+*gfZ4jOTo>PFwr~>gwvb2QAlnuzjHCYMw>urTeWzhA97j_2!LLrd`1~KS3}P zm|M5D3JzBRsRXDroopvj)z4|O8|e8bJALywrVRZ~ObOkp7l37f)>!EI{$ zJw)?_cSW|k=HBJr-t!kO+|_kue=JH?Z_$xf8od&$FqybMiIx*O2;})5j$6xwL`@b+ zh79%c@wvp#zy#|7led@%Ou;hC${_3v9Y0=m#@CKteonr}uLYV6Oucg>>*np-Y%8(r zCMIC#sF(ohL{loq4#)Cd3bEGRqV6L-`G2A=ebmg&-yvUQT4rr@y20f3N;~D*txPo{J`clWes;V9l zgK$f%fBaZXK*X(I5kMume*L(B#JZW272;QFj+AIr7k6-wITm)ct8>D>!|CZ_*OiVP ze@Y)E2u>Ae|AO!s&>oN4f6CkR?4@`!cjf$^ZnnyTVBq6bb@@{sMhp(IsN2GgWfE~NKR@_)H<09z&XF2gS_q1v zGj;4VC~sjpRZ@k4QP}q~6kT=A?%~IFm`xKs5~^TYB)j9A?kTDZJpDW3qSm~uRQhl( zCB=2oqDFW=t=;hkEP4I)+J{RjcaK|H-d$Y0+)q~|xmrRXjB~bLEF^9?Y#XL9qdiSg z;M>)?Syv`ZioOy7YW4sYBnc<4B>B2Jl}%NO{nnLd4CvMBARaS9UaIw4WYZ6s4#TmOw0o^GAbeLv53G6g2+((hxA*p_ zsNzMqNQVyB(72;Z6A-n1J0~oxEr3%&VRvPK@K{Jagv+vz>xb?@ARu~OyC@PnO6PvO zdqRJ6_eH1@RQu9t(V0-J}j!%QNcA6;~_mV#|=zU|n!n)YzFDu8tinh+|tEyUZY>0`kmMww4 zSM(TrSJ{)XH};)s4tYP>ZA9d|tWkF>tt=0fbQ_9#^wjZZzN6`H1cK!1h`zg~y$cy4 z?PKmfT}0q(8`Gf^{t+bSf5=DEKiNaQ7Iq+g)`L}?s{ep;p~PX9MTVV7%h~nMugkJ} z$H2F6y1Obx%1fW^l`?^JcStTUqlFN+0Hy~uEYMf~*Z>P-i3IDXOl7F)iTr_q7wV36 zkNR%<3rQ;8C2KD{A)un#dPQm?{}V1b2e|Qsezo+_KP9dGc&3NUm8tf}wDDDEX>o0> z`VEWWxliqNfCCf-LRX!e-%jP2aUnykl$jz>rynx(*+*+f6vS>4x_uxHuFiXP%f>kE zC<=_1E>W4n}#S~^>`44oc>v0|(zt5Bi_`FVl4xEeC%zs$B2` z3dv&03J;k^w~?i<(o-<_Ju!Cbrn zs;be;L%9zL2}+Yhp8T-7ebQ7%04P@$n;^Jy>a>$?$5aY zqgQ=HLuHqxrX?_3>YblhTYdS`^7U7pl!b*(tK)sOwh^nueO9xWiQ0lHo5Kt&OtC0@ z3p)1|&#&BbYEk#dqNgCB^(%gG-dZ;m%ig%2YM~=LyM-|5p?@Gr!zNqO?;&%!b+cye zqzy_;baZxB^&iR&V4F4^?Bwc!9g3rDuKc%Bnofhuo-&g#^Le66&R)vN*>0I$vV2|3 zv4BT+mns11_3bO|(+N04%OdouinTKMV*O66q*X?F$cllPq7HLbcxKgjF5xLCkD-z5DvsetO`YAA3y|mx~;7V}zr2ga*w$xrG*{-C-0g3A$!ZWZYZ2 zPgG2>z_-XoUb^&$v(|L3i5i_2-Z*6vt~RT)OTs;=%k=ioXspQ8651J!tR1W@^3-In8ZDQ1k^@G)w1cwU?N#47M z9i#(L^VxmPgY7^N0jT}nA^`#kf$$ynG}ua2$y>Vr+SXBQ=87J_HEzJ-ipQ=@c*4!F zJXxW{82llTCgg~z>5{>A%MCJmolFnqYV{uhY>GAG&a5o0g~^BWc0U+x^26*)o#>9x zP|@>Sq~*)NG)}6U+CJspei)O7XO`HS9AxuTSx!HF2o%9b zO|7`0!FcRgGVA~yYFB_<)@-mrOStT#Bww0xg)s@88m&G#fW@Jxh~86g!!sBXYehfK zZRh;eX(ceqXl-pp9XD$KE#}!bmzbh45C(U^WAw?wLCwB>$|Ppy-GDfZDHG100qwvG z$fdv|{C*cl7c+LD=f05Hgge!XS0!<5?W6{P?%q_7xRGT(r*8#&noESCjFyBn9ROR%XBO!dTy~mI&&q6 zYEwwPhjO~PvtxYBwa}_*tv#ZN=oAo%a6a>$j+d62QKk&s4gC(!qHg2OnWaS_cnddE zcswXR${u(Em<|vWsAkjGS3AM}p4sHsY@)N_?kGyX#93bG6aDeyL{K>J7{?RyyOrrF z?z2rhinez(7n;ikASHy%snDOacOS8A&6<$}zMns-i2#Y%gFfI(!0WHSz+DL*&`)_5 zAYJJ1TJr}cIGR_QxpQqZ*BBjyorsQ`zy}S#)K|(lMW{)6@L;H-w9rY}y(|KI3OGJS<06Q=kS%uwrDFQ+_V@f_!mv*EG*Z`?O}(m{z;!YRA1E?(zh zRev;O|8qg`ydK}5)8*4-HyfanU8k{D+ZwCI?%2pG+o?|fs-3{uhd_eq^;=|3bazYr z{0Q2Kpc7blHy`Q;Azf4R75coVS2k|~Dx5ks93}_SMOyf)EXJ6kJV^A!k}X>%g`1Hx z7@q;G)OC!B^|0R6E|EK&z>#ja?mgx)_mIxeiNBN^+Vz_q6BDUPf!|z>NS=sFDh% zmakQdwLHIGz5tiCPmSjBp zvu4_$>;;=ZRF)Pv@g9;4ggv{QTtu@n;Y3;pqp*9F4!Io%S|f%S@NE{qz*LOUiNRpb z8q*E~=I9K(N9gIvUi3GA_K7LeXK;FaA+h7O(FPAM5$=QsB3-(4?6&tlA|bohz?Z~z zpD_?*2|+RMaXpCX((1nFN9ZolO-)JBQyEdBurdWq82;;+M&bRVywXJ+AWjs0WPbse|~Of)fRdS=0Or*9Aa(fi2p z0GP$FCDt;5gMS@4ei)K(fpP2W_XF7X?j1J3@Fxiov@6f~^isCXDM~X55@PH)S+0#P^k!Q? zWp(vRFn(A*cV1Iqyk-z%my`4l@UcyWVnw4@+CB45;!(_ZVUuPwV06IPI=HCd(MSMnp(BKNB$S9kri>x%NpnA#65l*00n%O@6%p zwoA7vMGwt7kdm?)bm~Dy=*xF+-q_?+4RhIKml-z~l-}C)2438R9G>MkHI6lIfi%LQ$}XNS?>pGkJ_|j?5Y(xi30}wr89O6Ee?f^h~` z$64G?9HZu<@S@RBqvFmHUS)lROYhL@OWfS}>wkoFH8FVFVEYa*m-9(0H4~)xQ zWPjb*+Q+hbs-=7ux)#c>Q)6d$MVe>Y(S4p@@QIXxo$`wbvQpcE&YV4qut(qL=tsRN zuC5re{o-%e9+@I96?|)m=t!eIF&*QY*lp8qrNXm?*xFI|iwV3|t(rugLRz0qyC{I@&85&Uhtj& z^R;58|Nme7f9jEJc>J#z?uJdSp8sg(hIU*oz7UsobC7G`f8T$X z{Ii990;7Gp?bHdry4=2`*Q4z<2{5St`NwTNoAuVku9P*-s=MEuH?tu7bj|1fdjIR+ z_SYtrsjgJb$k%I~b9DBgPU&cI-M3Gtztl`0THC~GYWBjq10m5=w_Qq_`HJk#z}Eyn zTncP3JDG*Oz_md3saBz3$#gEa3xwqgOJgAwE?Qdwr8zzhhJq&gin920d!zohF@W_N z)4)Fd+a*P9MHo*D|xf(rwSiKY{%=W~&fsfertBy}H5w{hK#})vtFc zKiBJ*GP7Z5-h};UcAr}@bJG%|GxgO}QnaOYsc6#Zp&?X^9J$8p?}h)xB)b|;@&5Eu z;_NIWI+M3{7U9llr6DYY*IuAbvpS@i0J{#38d<(Q3wNxsL|uMbnBwi;m)qN^NZmS**LHD-1cc*K^k@Zd{cuoL#QZ?hQ{G7uuE7M==D zp}agHZ~Zq?LBK6&%UfkW`+uj_MULoG9@UIx5Pme^PE0s^e%wpt z&Z^?)9JPZOCwQfr=&Zm6c6Ju+*x!%+cYry;wvu=|xxTIrucVuNLi#U0HgGSyx~1gg z*sNWDaUD{%2R73uc#u46yv~qa8Ka2^;D;fyqY{nf3qw4QRp?%v8O_$st2eu&Uho+X}?SGN5pK&}uG(9M*+8zd(9m+m*Q^s9JO zv-zAkyQ`PNxRU*~7qBO}dduTUUD-R}s-`u|Z%yiJ#u~a!nAMjgso}vvamwN5zl_z_ zms^-&qYWq{G`psOVv|p6z?Fg8FMv3OXkrKBKpw6t?KS!vwx_HQsB}3Wmza2t{gpL( zR4)|1>Dpo~dQ)<22NZqkIi89LX<$&UMc-A=PQKizdFB~F?3a7DUS1iQ_FXPl1+o)E zP1?#aeLpga5D~!QlL(ulE^BnfV#bV^61QWV`M`L}x5O&WUBcc7wq6_cM-I_3m4-ng zOk_1vK^Y12TU#@2w0SR(usnKd#ah+vRH40yn-dJ$y*LwIjk|#gUdFr5cmzVX!r3gk z`(zb>AOOI*i$SiTA?_MCC%xwUf|Ehik#X$MqT5-rxoMyH^%WIsvY^*dxxYAf)()o! z7`l))KexXSVWB26-xmmACmIODg+S1E=c}fsV$V?t8-aM28mRai z59oN1UM|f%x~kciVIJQ#O3`L^S4bQOiyzC5N_Y3c0qJ(t?2!EM;sq;JS{tmh?zvwG zmg$ydqYVrToWqZu4hyG!M&PxL;T?xe)nkdgiBKf*v)E5YyIFfiP&G;_=j}Tl!$7q#^gMy-p65R497Y z(ydrgez$jo=q5GYqsZ10tCXs22TT7E93KPG*7-8bNhk}rY2+@*mDYn8qnuckAR`FA z2CmpN;=~JIBafZC#wFj4#QNeeje|O$onJFV&_rCNL!aw3SqONee0r4wDn*%h1}M3C=P9~9wn>IY zM|an?_x64VRj&2)6p#c~aO*5^!;TX^gV2d#(G;^xJ7I6d9yPBC%*Bra1>?^Ha5Xe&w^?omSgKZieu;qr7@8I_ya7# z&d}zz4XhUj`;5cchH&oAj^TH3ZWb3bd`X>|X@?cD_y|Tpa4)dIDZWi(a)G0I{1vAS zH+??1#vhZ2j?Uiy@=2lm0V8VebzZpe3P1O7&{qN?OgV6&cCapPk#5IA|JJ4{D~#~x zGZMJfpISn+5eQ%g@`7oDr;NxseflenXS<;`jjv~c6Y~Wpe0C_dVRF&jii8}3Y?aRf zwWK^Q&W5%lFN`9W!3^LuB)tszbrKix#;(qDz+N-`c!|&>raY830re6TBR0Br5?R+93!qwkT$kloblK@Id4~do9ybBr`oX%|-i>9Vd2EOkVeTH0# zFU4?O(b!1Zqu&q6-w5LSDe0 z9+_X%l8pSHU|^6f50*6-+tRN8%cpI;f8Tzet+hI-M;9v-LETOyFUM< z#Ory* z8DTwdW5o#w6g&MS-aHft$k$jk2$`N7KgJVyrV%ZomatE|Ha|a~4~ROE5YoN3W)MI) zeM{~B2xx0u1(rMuS3V;CLzIhmGjb4sypb}26z9U(t*KhD+-5H?a{hbH=a0`@yIjt; zCUH;0w{25YQc4-`7coNpLz z7w8g~EaI~$8GxBalv9IA-MrUGEoLgIP??-%rUw;I1xC2F?AAx(!r*4g7Yv@ z`o%fIJOnBSYmgnAPnV&bK4cOI70B}+x0Lh6#l_gLmzP^)N2m2rUw6%W&6G|rL z=eYNULJAXrqJ$#g*`@WnOWauFF=dXc@|T^1+YX@LVj#}E1#wQhibTP_QehlJX~^Sf zarNZ!2?VdIhx?3qA@=71iG+dUR+dP?$-teiU6Op=QOAPk+ba6xI@x;r{*SC;NJsY zeYTi&Bqo_UB7v7J7EXq?fPT6Lf6cyQ6{WjXChgwcv&JG~*v&DOk8dP)VXWxrwd;s%D8QyruIRQ_^c-P$J(%GuuT{>eKlXx+S#>^XEUfkfmLF_58VcnvL(u zp;I<&5K?yM&3nkbcmEBejOLqKpn0#5Yt09157^16=B^2o4g_mb5``;&`~m{-QsaaS z*Vm6&&-4jjG)^9&v}`yTm0ZETvGM+I_VgXgl!ZUY-eZtAZ~`mDQ|{K*_MN;1tF%8v zlCON^fCq$VeE?NSc_R^(GB0^G+Yk|j9X5kB8bHb^anA>bS zDy;k97hmDcq6nj(o&IQGf;A6LG3FW<*Vm3n%E!}ZT&%86Q0^@VCTVe1`Cp@Ddk!7C zf*Sxgj#G0?;g2Q?3;v+F!Bo*^Ewnd&?8oM21r?PI?;q-Aj{gUJ5hl-n9UY@bG1z7% zK1J`~fddy&gZ#S$4Fa_5^4pAr3EjHN?8J}eD|LpANP@$ zZ)L<&TBv5Wa0_wZ)T*b(L1o)>s$a9Va@DFWuWt`%|IQCTK}>{UNi^nviz`_PZD?S? zZu%Qq=jcXZ7T|c-Hf5-ALB>f}Mmxfl%MJ1{I>)Wqw9{^S662o)>xpb1B|0|gfd#YI zop#FBBvE(D#*M$5%OaZ_$)&>0qX->LdhbUVuk!H8ljCDtMxC0bIcviPqrEqa7@cnI z`m-qM6|cpma0M$l;(vNwtK1V*&AJF%_63?1Y1i0UfS1F})x)R0nw;w!s&_Z2dO=ZWTYsoD!CeGLPPOF;cH+CWVUeaGTY~jt#(vxIKa9lL;00S5cwsG_$l;sME!cnS)!?=ay(66M!=j zE*USk$h05;f{@#z&s%=cWA@(Il-7K1p)TPR)06rt$Ugi9jJGCPrp%zuwA@%yAe1`X z$v{;IC*_*Q^a+rQ5z~W|bKIi~gtn?}HYvKwRi&jzF!eouR6;9N@m#t|lPffp(kjLa zhxLTxU&B%Hhw1JTy=BUpiklc$Cz$~I!TSK@Z|z>g_b1nHVg5TRQL&=x4l|^veit29 zPSWm<+5I?E@celv7yM$NGHk;D;x;9^&=<}T3^EWZGyVOODc#5j!#Y-H#vML!;^WJg z4;mV*KYXF641IpmdV<^R)JZ$LTw#B>t6Z6{x0Bk-Nmo<Yxj?R^Vbxb^yLH&KzA%_EszlQS=0 zPU$!|^52Birw`s0FcRDE-MAsSaVxD#i4IN-uhw2*j#@it@6>=_(Q9UVq_dRUMCeup98CHm||JwQwTS&tf) z*JWMoQGhvHsBaOc9^u%93xOn{^7o9^e8a+9>N0z&lUxpVUVZZcDur@nbGHXL4m?=(NY-^0gVNN#|@%y@{SgJ+wvke!z?2Cv+t88E6!c8V2691+IknBBLF1*8&_Zgf;-A$EL7fZ*Hir!YtY|eq^!j!j1A^lHS*TZScwLjh6|U*D=_A)ks`W2bvk=wCaY z#hAW3kJ&+sn(^ezI)qW}0iK?_XZGdymnqYz)ppxcX(QTht2gq<@J`X2xCwpy`2FY4 z$O(Oi7U|k2Tuc)>$N<`&&+GQhzC>$DgpKX7zK4TvrNcI7@ICaQzQz%ox!L!bG9!_O zikq7ovs!A%T#3vB&ad`}2;lBRkzfU~?uZeI`$u<_$Gx%X0fa8adyKOL^`65vN z&KqN$)F0({m2Ue52ILnnwWFFQ?wMnr zHZCKEWgb3UTwT5M&TEDqL99P#lQfTG#PH!?+2p$M+c6mQ?R_Swhm9+-7`XiY#Ok!! zmp5gNzwf%yv-3ly_Zn|2JW5x>ylQS1ne~FWiUN&2I|FxX6|6K9l_}$~0QM`s8CJIr zre*8-LmD{9@7}dbf7mcKQ>||G6%jUs4=N8trKL6Vw(;zXav6@@cKwRNjj9Y$8tUsy z%*9_0MK*>JmdC1IpRytz*EAjWV#qMZ5Zw#UO**_+pFY#BxEag|?0)z*n+l{BgsNze z2}x$Pv$~sqq?5n0-F^4$EY_dMPBPsL?~Phyr==hNw?_{|s{NCz+149BJNuCE?Ftq4 z!j+zQCm8IXNf*McLh~_FWe3{iikTzwuBBxROA48%)BSu<+=z}nHnz6S@7`55`}+1) z>!Cnqm3Fs&W~^p1n8AH>)NbDI66KLJn+%_4IhNK&Z@!*{(M>cG0`g%$-jR-)J|^ z9{UWevWBav2|nGR6pZZsuIz}mn&-f$_$)npcGBrS{E<~`?|~Wg;`Qq?VI7QXszywj z{pd3B+L!hXu?|zSznP_`P6-I&a@!&GWnVlSaXfI{b64?S57^|4@ATQThN*Y9B9EKg z=AV}4(Y`shchMm)?_>3;hKBjTd{u!tqK;kBxASaPqZh&Z0LFz)$vhuSn;?R@t#wqB zXU!|XcB02kv8NHhH<}|3Di_K9y}VFifxaWM2dt>PcI2FCv=UW0ZeD8Uak?JyuJ2`; z09&K>X2>3tknl5%{A5PQs=iRlDlD7_duw)k?1-cjC-mAdY}*2z+}Ah1C{#9rR)BiW z%0xjzp)6%VWPbFtoqh2@-Fc^bcH-#d>O-hycCX|&SpbwLSsI_JT@K@Q1x z==Ts|`X(VgxCj(SgJ15k2Cs&q8vsHqT@a=d-C1Uh*4;){#9Jeq3bWEUh3(AQs-C^h zKkQ(6qxK2+e+p)J#Qe$w+hmV4z4ab$zMSo$*vf+#*>4XM5oiw^M(hhIkZPJYzGG32 z^;3D#b249k1-*GbW5Kzg>g}OY`=`OIAiQ?(sZX3sN>X81LaxwH{DWfV$s`G#U1}PJ zqpz%&n)dJYba8ZybWdVXQ{DrP#})QSYNQ@1Txhr5R1F};53I1QEwH0~WJes6!D|M( zMn*|hFFTX7_LSr&S4`S>V$!}iA4#KC#}hwan6&B6_x=55d{Pf@>3=_>>ZRH4F)eLx zC~d?3lu>^vM*e~8)var**n}cPDiUQq%eJU1O0!3o__)}^qxF)uVkM-wCdIQRhllri zGDVc$IONdx{Fnf2_=A4MmAEx!>Dgq%&&|@=vgo{wMTb?md)Jaf=j5uIUcHL=8OVU) zSowo~FPH&i7J&N7?ff)>fEwP)Dv1Ht!Kd(*HWEVko2Zx3XhrYcD_eXIOegwU3z%cj zD^tANM~&(wAOD%WBr9-2WNBLiRD!Wuh5cPX5)vjHZHr}jEJp!CE4zsPm6byDXXTrx za>ImpmdV5!4*ISa#RW+)H5Ts0Pm!`)AP6!FDHE$1U0Yx8-16A;;*`NqecGo=hl^h8 zx&D&IUOz0>C>3Afq0(>Apf0=(d%u^Lj=oi|d-mnD2d_?_J-c)FZmockGG%T>enoD* zlmO9EOdG z$1F73zyGF3K~(ZPmXbYhY4Od{U7o2E(y%dF2H?rJ$B;=$|5|iwtE-1vXLeT2deQ8E z#>QCkgGXoWwFMU>=Y&4mTWY-V>IR#|Gl2nixE#l-@AcPaxnxRv`8(1IG1r9J5ClU8 zx)Ub2Oj$)~OSwRyDx6gKz21d4U|wFHh2Xx>6z{5bon!$tu=r zj{^Px50u-Og ztD^K!ZS9Kd*EK1CSOy`ud;h*a&;{sh5Fj}GAkYLjmS0%{LQeetXpy@n$C7S6eVgHPSIX&RFlZSzo$xx7~VF<8U8Rd}JMvGIk=mmWzZD zqN?(tyTBK`iB$OO*RPdz`DyGWp%q|*Z;Sj<44n{Ys;D3ZUFhuGeX)*+z;cI50T>%N z4pT#}AwEHrgUz#MY6=D1xHb;lEu2wP+Xv9>Ky31VSUe4F(1P&YT}~P;69osgb;E?>|h9WWUr)^ga3GCPn(9`!;N$lluzT?NZC;)r{qj0X&qV3 zVi$T7(0mH^_b2#r86rqjaYpzL>7OvzmBF~9-F4lTGPgVn>+XATDoiW)cRz@inA zrZeW~Kc0Pr2!xSRZ?D4=0B+~c%c=(v0CeWIw}-xT4#|2gTW2(E%rT|>g4>VfAwNW~ zF6p5>KkAOscgN1y5_=6VTzx!3-)bCLZ(@>7z&SdgM>e+iK0d1B#|M6z)Q6pJ%d{RK zTw(6m2}Ih-jO$=G|2gCdI9E<@Deu7&Wh7bOGzbw!CxX=i==A`~XicA9$_>lN_{fi? zj2;C@Nzd^1ty^~99l&r5W}$n^PS)s)2)Tl921baKdk-uM4#UVTY%bvEGZ1gUg5a#a zYf9nwQ>WvTb~^VH`8=X9P$@wF{leh|^DlH4(7MpU9+nAd%jUpW;-gM;>A8bV!&z=^ zE!(loTdWa_9HWwoc9g_qoAZM;*OJ29$Hc?brn8m4SZre7_fS_L1wkDt_h9j(jp$%g zqAWQR8roH8Ws76B1K$MImh783QuH9>ugv4_L1MQX_nU-U>OZdnx;s>)ycOoZ-h37=tvHaHsqAh%9IXA} zEW&`Mp24`lYc-qi9rvukRX#vVP%bnHf_jeQHme%{_`|D?Ge2hIa4Svn|0|Mgme6leV76$`vlAaiyiLJxGv;!h?1Y zO#4pm>Bi8FcuD=H0l^loteYXHveQ1st>KWVLzB?+;@4MxlIfk5{f}NYmk9TLs<8=6 z0by%P*to=}rte-lM9ghzkedT#4uhBBkPTRt%s~Y~5K2&BwGnpx_x20iSi`bKg0qMA zF`?$NMI=|pK8DY;Xu~g+p=V6DE?)dvXe!2{l)M;;O9s>a@KBm{IU?B8B5?n);yor~ z=!_Auxuvs1YuKJ{p2tiZ9ZILGY8?6a#XAfnwmvH9tit_mwSKLovi$|+i<6&;Pn)yo z^fa#miI)j+Kc1a**S3V6&xGX~h~Y{LW{+W5z`1M(jC02pB^a>U#x3O^Ov8Fs9FY<2 zV?_Ajz#%(r5RucwaOsQCw0^R^%meq_O=e1PUhNJJT+tO~ zXg-XtjA_|IU(GIN86WzXZAp5%NSa1Q=+xQtu=uHAYl2^hXe8zGJPM}Snh#SB=%tgS zIv1qA+xvNr-64BZ5f>FPsgjHuWJzd3h&m%xZR6(6fhyZfw~Gm&1I_(p$9zZhI={CB zCY9S&8;93;tb$61m%7yl@1;xfWo=0ckcMg^st`DwVU=>XD@SPl=}RU z;h|4h0yosZR~>sn>jYHb`t`Xnz4BkZlK*zk?Ea#uy_6QduY~QOnQFSEcOOv!UF${X z$Da>gJZ52r>Sp`!ULbofG{@dn&>T2kbGuwub~iDV8ETV4^VOA`_Z5Em>!7C+sbs*E z#=Tjo8Yij;4GVVYolky=?Fw~(p*-i3wR#*B?pfE}5vPB z+xXYLq;C;}#?BWD)@z%yP4tV)@ua4!H`c6QugefY5)&+UNdK0^9Q`SMDI($*j0)mi ze^&_J=`;HF_twJ81CS44JTk%;IEsr6)n!Z%#z@XVUq|&gm-B!Y{i#!1Hf)gJuB?TH zg5%>en=5~u{!Nkj4=FuhR{viA{dK|V zqbL6T!py#c;Qz&C|9|}p2Zh5ZE7W_~XXO?rovt>Yl^;?`(}xK;llbx z*sNpH&ieH-U+?x+p^-*SOId?kS74Bl%PM-)LRp~TAfjY9FV*wd;$k8edFaqAoKf(p>nuf#ZNXm~Q7CxdO2*S+cPOVmabyQLba6_~>T^ALx>hT?=3n){Zd`!rWT`=PnjopY*1OG&Ak zRFC@k$c(b-*^7D~&zOp#39UWvM{M*y~xpW2NWrRPZ#F8`xSeA-$z&g=Wy^}3hln(;p#BlQ>gx9F^o1t=g99JP_2%Am zyU*XgeXFfK6g%WJY%wrVc+0ED6(*mwia6jw3Jq*r$tAl->e$E z-&(;&3+>M|Xk$lKH(=E(V{&Cy4507+8L}w)~rNfF6xWUv5Q1@)M zL(;>Jo+fN5KIhR%8usbp^cQO8+p64I(3Mq?-DS|Z++uueVd?2`=2JNqIr?qT9_E`# zd0&V#W{_Ca_c1d0XOub7oOkS?r+mhB8eat$#)mODg~kPXsT@auj^_sk57rbBFfBzr zXl&yfDmF|bZ9GT#r&VORs|K2qYr-ywp<2n4$-{Ac50d~{9_!abv5j9J9vPW!r_nGh z<+VUSh6vJOe@u4IZcFN91rS|Z-IMaYDU10P)Ec!t8S^|Z(T8|~`1|{bYjpm0JmPD?UwaSg z?NyW{S>!Xj@nH7V46t=}m5J|?=obS5C05SfU@N!WDj(|Z>)IW?-US?75Has ziBfNRd8D3>!mjY?r)JLY>>899qUuwhrsGIwe_+-ev!lp;cH2D;`x zeN-nrWT6ZL47&R9nS6R(!ZCcujva1SoY?mfMKe`~@_?fP0li_afqDqvV^G$-o!sZm zRF(LHx%KDhMLn#kqvJh8w0o+sl?_bPvK{Of(gnJ;trQ5c>m(AH!VSd}-E8x_b)izP zF@9n=39}^HDNB^Lc%%CS5qY1mU8Q92tccf4-L0&t5u0X1IZb^4A6DtkNEoz?wmtkd z4cLPEicXj0L6@qh6^(SJ>HTI%3H_AEqHPTHO)ghRhVBfIP8-^&FDG4;hR0Qa!4@kKZd zR%c07+8m=&DSq;#l8ijf_M~YzEYIsPF)pNL0~V%hYy!Z5S)MfF$g%(80zj;3(aVGn zOl^Z|^U$|uXKXW*_y4A9XCAJ^9m<7;!J48$Ps#nk4d2BOiL^zAv!OGz1cgl9CeCx;m^qPX{vy&=Cy?6O#e34qFX!-L8N#`)OtbO|R|( zid+u_@d80;u8?RP^xP}WccxqszMOzMS;LN zIrP+p@=GYdrA_tZ5&UrYLlhb0CD2l_+Y-NnZ)xwOCw%_3XV=j?5|&FL-O}*JxX0=P z0;%jZNJmHR?RHGegjL<3K_VwTXfVJICDM<*etnkO#wB4Ek=!kCRp)XhlhBY5GFfAs zJiWYt6mDO-BuI7wcfw|p>YKDMDDuFiC`=D$&_E?>3(YwWqY{5pP*A@0#--LM65G^$N zh2}yo&5g9IV@cC!jkJX59ExBpjOZ{rc9hq zt`5Ks*1pPyv4TuFsVTe(1+N zJw1^jbwHQ{VrE(?D5~Hf09M*WlNbS&K&?th7%FH0Qw7x! z{^OMG*Af%Yr`iD`%D|N9ZE2XFRgl2oMBoCbr^~U8WQz(Lwm4@dByD78lx_Bet0=>{&#@?G@v?{P;G!mR7 z+W+90a7s6VgE6ic$OEt@Ec7Ap6)Z<$h>vX_I)Z!SMd_QX6$`~`G$g#0w<`MO`#>T71>(g@9lG?TjlhxH$qLX<8@Zg}i6eV;;{ zh_X&-qFqHpl%fjYLO6u^+IqhY=@77Kj#Bw^YC&1#ew~t;_=J9AVc||Vlsv8=^;OJG zgqN_x)fMo`J6BU%`@H5_Q;76fOtQ4`82G&7AtuT6Z@N0$zX6?O?I_s)>4Ga%3YOMh z2%p5u{MsFSJHQ}Z*1u+M*xqA+J`*X}mXxBfGlfKEu)`cjXbS_E#EE!V6(EWaf3?0v zCBQSwq#27|o@C7B%az3BtlV7kH`HWUd#2R8K?3OS4W>I9ud_#FDb)6ihb=(!y|`=U zs*2B_uVYTi(Ru6EINjwk0|qp~@(-vLQr6Ez1O%QM(fiWtT45`B(#8T|=a_Fr-$Y`~ z*?{11*&KLi=KX0(_@6&NPEs6-I&>9)*6CW8-8NYipCVOcEZ6$2Ta&ka?9IBa685#Q z=|KO|wLisArI3@}T7Afkl^R-ZDc@p zbt6W>6DCb!h1lBMG?}YB^n&2MJr4vLe^c>chKBu|_AQZWnz6CxFJA0_KfuGo(9W)$ zqrs)y(qi`kbC%p_ERZoTH7&n*`mO$P*#QGkJHnTH2SY+2V2!XI&rY!IFFHanGu#3s zCd1wZe+(uF2;}}Zrlu+L%vv?y%)h!d=@wsbx!7qq3B+WuDrVR7$r{s>mU9H zlhWYzYfKQ?p6EXWf)XY=!5gOA3M5Wc%AXLfZ76;G2i>e4yLLsvtWZV?QVIC>)OXYY z5MrplfB7yOXOp~Nf6#AMCVU8v8kd9;x)p+%XaOjx5vul5iG}t}sy`8RJ?D1$#8|oi zjqn*bq<%DfP&8un?Wwa4_3($YcVPi1^3)D@WLbo9%KS3nM-{|FKX-F$q1Iq{)qUni z*Y=$vyI|%_Bwi^h5}e!*fiqK5X1W&=x&Aj6YGO4|>7Us@o9NLtm@ZiG5?v*JpE$27 zZO`?zFi~9miM0fb#Fan~nsog{%Y4e2eAL0&Z)O4JDyM_0^%dORJ_Ry0(YyXXXEb$j_J{pnMx z5IU1-11wa}K-tCK+%YGFj9f7De{eW|oY=8COQ*7{^H(xQdp7~Csznfw-5Tvj`g)hS zL3$MWQ-hEL%9xk>kgoC~WYgKh_GaPKZVoUrQA=)FyY?fBh}CiPO-$Y;Y>$9V78^_N zd}45o^6Tx`l+e+%{?~)l48hQU@Q$qt$_L+n@oUM#6Dy~t1a49hHhsv(GWRC?$ zOM|(&c{JQ?9?F86d1G?&va=a?cU*U!$Lb=QXBD4xFAAGzN4F2BCS{*j^*-DC*Ij{t zc)pPzrMXTz8&7Vw7(uyF6vzGKIn?@pcDY@3lZ$r`i?s?LKB-Z-aJz4P$l2d7QQa9qN* ziL0neVWzB8@xLb8f9T1Gx`R!-wDhFL!&7F7qu4%lsN3J!@hUAurfz}(eY6Qr0Kb#6 zMK}@#2&SQa`=}hec+Fh;A4wuI6mA5hpv80o=mm7VyorVE`0+cpZEGXDfLO=ru-L{% zJ@?x&vm?=S+1Tj{sV-hc`by(^T|d=0LW~5M@j29;HO{&`2dJ9+e*PVN*uqaid18P9?jjn{Wcu>~F1Iac6aWqVG4f zFTGs+tt7!i;^vzig-tgU=9KgO^dsweGXj1>OMO)5>JHd6NT*)J3Z8hx!?Of!|k>&2Ph#VM0# zc;fRXeVKCO5#w$QtMV&u!Yp3oZTj4^i$`C$bZPutCR`5W>U3Lu;dJe;wCKB z8g61>v>6hKf`S6@UokADYx6l_=Pg(;-SzAF#8!=rp-=V=Xi*mlo@4P0#8((su_k0? zP>GQa{34#BcQ~iF&RP@yUtAuT`PxAdN6ihrJyg_l?c{$5u}O?<_0f$K94fniEUs8} zuweHKN;A$m>H-|A7`y=3o$0FAx8B1U6I{lkv9V19E;S7A@ipt~-+wcWhNu+1;8N6? z=`ww{4AYu8anXHGT~2n6SNG{%G)AF^BqG6L3?r=!TBtDKDt2^_dUZzHu;0Mtj~^=Z z;Ep`M1&m!pm1$l7}&rJ5~4LhNiGIVKhx3LaQAv?QwY&&@S%8d=PEUF^4FC~6R+%j-C z=_fTQwN}QJW`NTZ!*%*E%quDq7nnQm2+kKVxA2vna=EC7JS-d)6>p;}EY2fdUSK(= zavW8Zm7jDSn|)cxjEr|a^k;-v%58(bVpU-88y&<)V+^wB$%bW12xt2lre?dpH3Npr z`}fn-)HZ@_04P|rS3^WVfrRjvs_asb==Xss&(CAGK=$gNkh*UtES3Dlrnd9G zBu!Y0=ZwjC_N;fB#%Bvo-&m?oR^N1?^;#cy+d0f*4tgH`x^_Uol4zN_&B^AEqRs%V z#*z;fg#Zkt{+@2KYkNkDwBK2Uf1e?x9v_@)OYe%O+4iyZG0vu3ar4v*%$A?j67bg`38SWiMip|tt(H?KkJ)T|o^)opt?dlt zDD>!s6;Sb-NtPpg#S|qP)SH%eI=$%??H1jBedCJL1#W8nRJZEfc|T<6{sXv1R&`g4 z3GDgO4pY_v696rMUqpKL3@{O_Je66MJ(dQYkl(3S48}zF-ry)CD2h zy^b-h>@^Tk+~)U;dLT)5*rI0<5|*#vCr>v66|O`!)^@=FTx6U8VW|$s+h4Y|wl1-< zs=6@yUm~mV{=qVnPgeHO3haH*?2>H(nyRdxQj5RM8bE4J>74PpFyCR_I=IHL?h*8b znfn|_L09}Qf%Mj`)S|=j#>(gQ6_Q`{e7+pd7;Gr3s&IQ*v%ayqrp8`IT2}VO$8AS# zJ2b=njIMCCleM)|B4;nly>WKd$%huE41OUYL+<@PKVNgr)4uM32j;d7dXbKztMcLO za$)kXKZ1w0Vn5$nU@de_x329CS`ImBA>%r7;Qh_>fyS7 zCOk$F@lv<{#mbbp_%AE-j&^u2U?JeSA+} zyK=>lsn5UL?r>*r-Rj$?Pi%B_N(90PgD?&19EI;V;s@grow%Z0oQ-uLv7r-_aE|f@07+yF1a^@#6>pX5V3>MvS-}5#e3dCvy&}o{mZ6ITXlOE(z)i11&DOrJ0w>9N7mO$+?^52A!az zp6dg#G`F(ayL-1(+ykbBKn9z!yJ_~dd8*k)vT9e`4~r0!Q7dq4wfK$InERTiTJg8? zp8^P&1v3O_?nie~3Y(jmop|wcMB?S8fglXskQSfw{^i06ifoBgTScb#F=Y-x&AHS7M1%j(3SnYd1-|u^1od{AVOeQ-% z!K>i}L7j8S$@9iF(cLGF_s{g3E?2X;a_6)al#|VFiTEARj;X0}#Jx#3fcd^mBBX7| zYT+>O|Ld=KN;FsiBmj&g(VaGpk8jvx!G`W(=4dEWAw)JEpO6ID_!Q`a9nX0SEJ{H_ zq=j!|zH}*#Ei|ncLqi#Gj0GdbGr~uS&2qw&(q0zM{WHkpi zpf7oy5CW|)o)ka1?%TJr($Xz3UbWstYEv5MjoeCSB5!Re3{j~6X&RONZ53cdQL*qP z;3g`*^qdFzC-34Jvo z|CrPVi8tqVOIw@&lpfP~FHbE^c1mm-I%%WjA5o(VZkb6*osd;42FM8H zxDYuxyUvbQaw`qVYll@_%`S8)BvAb?8eiSv=0uY##kkkuWoja+T>7#@fW@sUZ zf}v!|0f+TGZPaw3H4zxzZH=XGJ)zqPcy~-gDFNS=Kt{$1&mSr zYUoV`s5fQGZYW-yM)MXf#Qz7_oD!LO!MQ%rY-9p&8%jX>mX~ zdkFk+)V}Z3sD{iJ_sxWG41k{^ab?l5{1}rB(@X`)jI&r> zrwISux_^|PF@pyUq3`apTO}pd?xuk+=C;-CFc6w9XB?-Qj7ZA}GX>C2DkI9KCtWu- zuCS9avZ(TtHRsvke-juI(niGZUo)v}na!_P@}M?+Ban52Oa+Yxa$U&TnEU|o>4tXE zkfUCh9{LSxEhDe&>`q^}0J~>P;^RPNVS|8?XasqDdJ*A!r$yep*-l|c%WL)Ff98NO zCmJ#x%#KZ)ge_L^dZ{gw89qVlWZ0B#!l}S~qjd&^puQ8WkBc1B5)qJ|GwwH~FR&LK zD`qdVZ+rdDwp+i){>*-1Z;puw>Z+;-KLygr;=>Vbv9rOjhcueJek6kv%c4|9SS2$z z6Ok%`p}2V1)Q;>EkA+6LsEB|a?Pse9W`Z@ZSTXGTQ+&(-2|&hi`nh^o?i9ry2JDbP zG@91J&%im&kDz|cuQajgkXy72359g zKOs`s5SvQ!II@j4`UMWJ&`|Ec9?WV*UkxYWqC24=eVlXGxN;2AVWw^8&|=}jA^1Q4 zhz;g4``opZfQ-k$PJ!RZe}PQtEeh!a24UUrg;4}s@9i!iZKg~D;!<%DSwRaL z^8{Q5Se`s4zK@m}4b_F;Z%|yo3SJZYj*ZBTqC%mFFrt9x1BGS>H^*xGQCLvM8~Ki`hk#gB1&sicTY5>C>~a-?jXW=|NM|r; zX$+^S+uR7qGc=KTW7z#!dYBnSCnQYq&D`$ldH`!|Gv72E4LLm+b-zBYm)8^ex^Rgt z!)ulnZ|^3c6^H)-%@wc_Hj(x)PKbJ1SEq+u+|+PvMd1}g==P%r9JO1La{D2@59PXg z-U+P9+TcTSj+3Mc8`r;nQCr(Q+|f(Dg*|+IEI@fEb2z1>D(dU|4Q@e@d(!(*52c5F zPlpE#_8kExLDOFqA!oupfD<0ylx@{>!D7fkWe)G7q1v)qy&?V>pLSr?_4M>~4$z8E zC)7rsJM-X5`YU;Xz)gLyUih4AL&uLVY!H_?s&(Y)sS%LP%55fIdgFH0a%A~-*z7=) z(j!NHt*e8c)F(XT^StpNHV*JNF;Y3bp^rkt=hGy!(20p%kpA)8AOwILOL^;s=cp{Q z^dIJ1ebAo1lBA^h^5s+hite^mIofHl!&fi&Li0#V01+Xo*46}n1%~L6kydr1v6=wv z*xOj?r8gRT2pF2L>9I85GXik3kAc7ibTs$tPxJHgf|&puc>E?M+0@r75<0P8fa!hA zNg@-$IquXN@_kq9Q(F4H5E>O`*n^lnpEe@ogdjx-5T*B0Qx1p}S_u2}^Y_0th!X%$ zDwyQkB#vRM6K3d5dylaG;#6Mle{liEYIVDn&-uOE*th?Z?7;|_v7a0CVmpK^1D7Sv z=OaaKyeTYSRGc#!RvsI2KsVWPX*bblvxkMB)V2I?Dy*9N@CyqFC1?6iGJIj|A}q2b zWs(kd?7x;ecHV4zP^#er5|75XRTw0X9V`%Wy9|Tp+>|op-$s7DD_1)1(osD^N#km$rSAO)QDIPh3DhDG}vkk1-2^w7VIL>Ym$ zpDjnqf3p>b_t*3jPCysF!Ka(t!2fvX|M|lIKW@pv8s{N1M^H%<=?MAL{xX9v2exs5 z`X@|n@W0->mG%G9K|MC6XJ#%8(KcDUJ^CqJ*=tDAFD#g#sw(5pAo+imLT4@Jjt381 zzLG%$y{xF@zf5+wcI(nJ?o*Pi07*xql`?cx`hR~-JZr{Q>DEz)2xJ!)knS0tt3S6U zk92!CSp{cS+7f|sKmxS`4XEH5Qq)Smar_N)zzX%OJgspfW?5M=H%XX)#N<2NiO@5A z$w*I(_pZZOzA=hKB$l@$P~kgheo zQcPDXg=DY`xg;B=Yxl=Lo2=xlD?9m|e#_>Ormyny%9~paFArw>R2Cj`^z-%y8kn#w z8e^SHWbBlLbb3@(*TYGIh0zd`IE-hLT#&Z1zfp76=~Tv}c&uLy)g6ktLzKq|>x zO9tf|h%TK=nA@%zt)npC5vppAGREM&{KO=w>Y)?uMS`C18+`l~Fr2meFH#l5i_I&I zY-~yoi}tbkaqQ3^&L}{tRc+MU#goIYTe%6NgygK)EJF_I=!&DISfLqy3 zARK}&X4q)0`G?wwrztO+VU(5(#%VNHFjEf5{e(Nbn%ikUnr|j#UZ}m@rH^M+H4bv% zxW@hTu>VaF$fJPz(b1=`Pb^X&kO524K|>}_4z(=GFjj>?KsO!l)h|pBf*nw#PUv-? z9tulS4-$j-%rtym=Zb^end1Sg##2yT9gf6r&!;yG;zn${@Yo{JO4e|8dFM zO~X#L>|)WmPcM7+>{!d}-C#H@5fIhgwl1h67|J44%!UcE&YhxVgC4d56T%BE`55Hh z3uX^+D$i^6=oGQQuex~Sz=EBuFi5bW0xMgHY?X~}1D1?mG+CC5B}+oGeWz2N+sNCF zP1W8$SUcFvM0v4W_=Q1}`}t-rF1MP0|uYQh zwuwH?SA<(3Dq*IU z{Xqk1n6MyE=!&kjw+D3aNj3oD$M!ef10@>0KAd~X1$^{mz-vfGW3Jz0rozcl!ilf4 zf`hARwiZ)JP*gOE{LA(Xd-R`d&q39Z0;UciuK~hHJCuw?Cij_YvmqG-8?#cfIcoeT zTPvDd*?Cl_T{F5qb-_Xu5fgjz<>u1c-35P_3T^D){|`1?qZ%!}4lD>8TwCc&zzZwg z`%T|{^r8PGPZGh>-lhZ4@PY`mjkyW-6=>NQ`UGDd@mr}ji6|L#SmQ-jjbiwu2x?H8;vM5gJ1J`g)I=4DQH$U@Wv1UmnL;6R-khn(CtN zhwNbP4hTMNV)(z48_x#yB`DgcbR}LfM3>yJE2U>xMNZEFGBOxEY{S$J2C!)dr=%0$ z2fOIh6>wk_u3#x6P(=cwRHQgUhoAzgtAXCfHLUy}{WhR2OQ3nombv-ZL=RL>ec0a+4*nRK3)^6xC z4$|!`(LvD}Hj4(%w=bT5qEtyIM`lfkj&g!B(-*Y$H)<=tS+;0)biQAyU(ycgWp-cO z56FHPw|TkLB;{|k3ymEebI%8BO>2J{yclc^&Wok3?N;gwX=(ISEbfNMxq8iJJY$B* z0ELAF24;NR#;UYlwF*L9qgY=tij@(N8R?a?#{hxJHGls2F-M%L1p~rvqY6Maf$$b^ zfHzf}c%BA1;Q*tqm}F=(zYm%Lj5gV2Y=0a}-1Kbz3|;??)}uiT+KuqsWaunJ4n7RW z6TWun#$YB$84xaG9dn3-X)4Y|p^DT>c=v(HvUp|j48C>i#mkojWutJ5ip8mpv7>fZ z`%(T@S3CTOz#5N4oNoo#0P6uj2FQ>@8+$*r3Z(9|Ui%+^FKIx}j=Gi< zUbjh6yb|aK%f@&s)l!a*r)giL>QBv2Xa4RDkpv*8wA-rmetx&HpVq)6intz6Z>XFN$*FBpN6ju4Fi zVn%HLj2ZCe_mkm~(%5I=L$n{Qqy#BkN(9aG(bW@!4Z^SkhaCm*wRy{yPA-=6?_5X- z+(t43dMlj<7hPYp|Xcp0zqrW9O|0zO+1g_3c}5%iqk@l<&KP(fyvv^j^Vc z(kwyzZ~gol$&{;R!*%@5m$_-ivuBq-CJjIL*I$f46ZVoa-?K=F3VWdd{5}FX`aUT$ ztTlaz;|8B$XTeH?mt%X^A)9XN6wNijfd~Njj3^u?_df?)klFKi1~NJC(9DJ>eJ0=o zA6L4k&kXtVdb_tF^+)xg`{}UHzM9WOP?Eq zfn$aY0RYF|43qCGSb%M>r{gEI^mX445oUU8^+37^U@SS$l4EGq@LH^T23k5hAm(z% z`Sa&F@&WK<_fMHRm1$N`Q)EX}_EdtHOFEsfIKjH}=liWyy^G-=18Mg5GGHm4KQN+C z6regbrSWp$KHP}Gl9FVR#=LT3M`b1=_ZDCefqSsmK$1O)-tt?CHgi$e(B!+fZaqIq0ypjfV`)Wb4Z%K zXyrE;nAVAIi{vC>s$6^YXzwzYy!H3^oRxQ;f+L9a>{(S+1vRcQwSN^16Gul(2OlF? zPJGRiVErK=uZ5MdX2fyR@O%r0Dx2&zYi%;`3U$-J{(k4tC6vJO)6io#Q(|#0i_A*I z%@E?@{kJx0nXDPV(9mSm0(_dsP0z-%#e+6$OT0-z7O#z824poH7q+Y@Qe{`PE2_-(tU-Sv?ij6>lHG-d;2T){O9avBSWU z-9X$*OxU{_;V?GvS&yFmdBnldyH93i`4>k zfBWHe+3Ro%mo&-aBHxPjR)v^6VwxRO^+%7Mqkp)ki9fAUa(}|qK;6OoFz~rL@m}%; zI1Y8RxG2?~ZSmO*(RBm znX=GnP~L&8lO#7dICA$nwLg;n2Qe|JL&DVGUuPaGq~zdKJ9lc$x?d{kr;sdwmpD4HXVbf$d^_*WqYQ6C&=Ci23S4zAt19)-MZzCKoRyYLJC7 zqdaW+hsZ5|h`HU{kW(~|)WQ)u^n?jpmrYXIQh#H5z6pK{zD{K-6X!=-uCTUd^jwBx zY)>&UFn79*?9#WlmQyMa6e{cLOs~8uIQcea^;VgNIbOpbExc5&(P6Xy*2N1d3(h<5 z-1z~kU?Lkt8)i5z8~4;IewcR}VoYRN1Z1W^1l@Y*yNV2-Rxs^CH8nL!1FAzXN!LLJ z8)~TL&dB+YLF}FQ#G|HFHWGy>?kcZeH`AEJ4J?<`0Yc8^fO~ls4g-_?Rjl8%)Ew-w zgIMM%seWelUWwkl@8bG5X3VR#!xKF@vDkF+-SVT4OXDWmS8n!2S!h_CJ)G^?W+XoGIAkitc22=sBx_B&}f*Z-e}RearZ3jRlrRzBBVN`Ffn-rJQ$D%-R) zG^hn-9!Pr6PJ43AZ-bCqV?AkGqK`{O-s)<_-ut5g?jF7t$@IFJM`om=^NyN~q6gIg zxGFteArj6M$00*#Hg*AM#P47AYyFz;q7lbY4}fTpgiljdWxoiCzVr}6KK#IdWb&4h zEISS81L*$J?Hi+8b;-Yq-{VD_M%<|)ivmh+9ZXf!c;rZ~bhE_m9-y+wq{3O#GFmFf&P5s&o*A^p;5~ zEbL#>3d?d)^XPF;tDw}EMqj)6>yVi;^<<&u(7&wJ5wZ>)>Vjl;r0?6JBA){pFsBo& z)n`er`F*{r>;d=McY530%tZ2k=EyQsG-m&h#wIJ9-5*Isr(TQP=lXfq^{c%kDkpBa zT`E!@JGC~sjiineu9+Lv+e^u$@@Fq(v^WroOnTiXNFZDfvj5}S@qU2nrYRC36+DkT#fK2Tg`WoP$_yE}K@ zymQWu&J7X!qMz(Fj{!tM?_VqM#`8`a z`o>&SNy&i&qt14masal2cM=ZK+VL`2Ks5dD>XeU{VOr-K1>He%D*!7iD~HA5f_N$VDZ*x58_P z?ug@#KX|9B#QW29M%UB5n)``RgyNDx931C9N9B-hl_{pj2Ari_5D8wq+^$_MNx!*7lK7Qeo4x0SHOi z#lwW8ZCw1vi_4dM>UD0$fuvJoBq)A79_dH++=HJJoQw0=#RZx^z9PB1yQoeS(Tw7u zqdByzV4QA=vvjgoN*9&hl2CP$(-Iv^tv&D%Z$(92oVbP{mGQw+UeS|jLs;)}U^{Mr zg}HurouN343F=^%KK0$Y#=P)Nti2Xp={6z%Ln?gd-!K3u(=3BoAioR?>x&m@WT zk`g_5a-f;iTx?j`Cn=f0Y8uVOIK>^oPP!Qu&-@O^l@#3QXIXUs1en1_By|tR} zEGi0O)#5eCA}LydpP@S$1OJIP8j#&Ze(T<>0bl!71uCm;(Jz zdH@|(8ipmEW6NMKPWQtx;#rwHZs(y1hsGzL8xx>L{duPgF z&#!D6IvGS~Abya1PDa>@Ah0z>?_EuU#_le7@nzW?A`nJGET)m<4o8oMy{?_n)uFXN z)4;zPu;>wIm`<6y)|fF=$fdOT)YYRfgA-2A81;7Txz!wttQeuf6)(F)9^{ILlX?hJ zev8OAITSzM>k8=!a2x*u8w=DJ2f8|s_M0#%3dZtOz$mhT*49;wCUWpG^RLFI0k3>U z5}|B2_IH3>fc&MPhq0XJ^tR7e<{X!1qTZ*N%#n&FcOO37Lnkj}XDPNkJU$^lVrM5e zvYR3{?x<~$#MEu(lg*p(_&bMvRWI|o{+rgXFM`RzvU&Qf10WDxj!-iAn&tIA`9N~b z`r&S!$1#qYu;QM$yprEs;1KGiNinwOM`u$q+|g^T4uD_-J%3w5BFG?|w)j1L+tGRM z)++@^5t=XiJ7D1<+(4U)uLdtbv<&`$?2~qbt83iU*`4ek^s%Th4JAE^eG3+JVnt-cgd98Wq@bI_jH{PQ~C8MFZw_b7{Z!e_o6^`;2k>9ascdH3r2XhPOM+eR2q?imm zInr-97C(+tsV#Pcm_S?x$DVBXiFgq4qZu2&Og(Sjd3>-BI zV%hW0VEFO~n-u&$@86Z5|Mu;=m}Q%N$8;S2FD}65%|LN2Lmhdl#aie$ zUSXdW6d@XF6KwVdI>YkDvUn@aVT#QU&rCX(Gg(6ee*TEDac+7X%vV1;?|3Ih=I!ac zb?Egu7u}Ui7HqnA>*%Yswn`?aCN+nCtz|~qc*Cqn*Budy%Qi@Sl`<$+Y@ar1OaE|- z+W7mOYu1`52UKM5Hx9Q{%hW&gaMSh)W|zvJG9}{DM>7t&569maU%|1E#)R7my$<&Q z($Yo|ui;=l?U$z#uKP8kws3ULBCR{mds5Hbn=v@+W0Onn3HA2^K}x8!XO`e#ay(={ zeHx0KUTnUvU%nj2`shvXm#@&o=Gdp$~_{(BrvxcN2dOW?m9g75OD2yIWE^ z^t|9bjg%m*s*z;i>o(J9VH8k=qQmQ8@|2Qu9TiUN|NKpmqJP-7f0_tWD}}GV-Q+d@ z`IF&HpT2+olp-mqz(?SvA;Jy$^ZGw;`*R!qzkNgMJHPp(yu-`>q*0i;_~H1jq5nL+ z11pqFZdj^8OEXusR4eR}o%XxGA3$+H_En3>U%Qk4ysJEHrX;OYpadq2s|8}1(;n*G zUYrLDl{T%9@1NC^5`W!N%_H@a`N)}RMME<&@we7f<#~tBc8rNm}O4mHF=@a_9B(I=t);fnu z>y9eDCpGe(X17)inckk@m-2PiCvic8Rde(kh4`MiyLy?WbSaw@&a`!Pb=thB?b~nt z5w5|SL0d;k#@>jVS1{&A!k0~_jvAPEw_h5#HuUz``J)~@v-~`uF>6MXg+#$w>(|yJ zr#d}Od0N#HHD6r**Xb$h-5TAdOm}zLyU(TVXwuiEIkS65$=*<(`^eQeY+6u?+vFhI zGy6NHf75Yp8>YOntL|u5#f*d_j<2l}cej3CH{$ZfmA>NZl53}oY5%%GOfT)zuU$vg zn*3jROzR~sKQ>}!*@owR=imP5(IZEvIX1VepzLk9&2uejv6pTyUOWpl2*0PZ`p~Ii+{GC4>8etNgV(8aDM`iMd>it9Pn+?b>OOy;F`cv0aQ90u z{S^zwyv^HWR9W-nrt!Dh@mrLfg35goa#yu{OndOO>CwV|2hJV4GivLtEa^h&jN5zK zKU)pI+T*cnC}Ge}!fE{7i%EV*!xt}Hw|(WB-4WN1tg-uhM1n)@kq_=#=Xzea|0u)G z#cbS#WwNEwAN3Ub44k*FW9z9yjibEFtZ(NIaWhu8X{&sD_Ee`^kEx4u)ANsJOiaA9 zJXF;<^s`s)@1S)G1$P~~fA0HWT}#KQ8nqt-hIHvld%u0P%~sK)U2eylwuFW5+HY1k zy*=$|9Ov@=o=?J{C+aHIo()Wtk0c7Esg!vwF|vYCEX8ytJZ8`cH6ops%zOQ zm56=gXPH;NO{^NJ8FBNYV``wH*<>BZcTQ2Ww2W&~wKmy=#?KG8jQ(-5CaO$+$9Jb2 z&o*Cwc7I31h31xBXD@au-Sl$cD6gCqGj^-ggstB6RdV8|cY$5STKiguzlo9`^z8Jr z?5pLw^-S zSYNDN*52f6^X~13kiCtc_U@gls9K$$+Bq`ZVtVbEh!L$423cd>JFiCc6|_v-yUfWi zCh>UtqZLk#$+uS-R$4u-pSC$8!#4B8;7)sOiL$Jc{i(5`3s=T(=*sk+USvIAaBj_d zhb;@kC*B(uEZBQTwYyZIXq?)lUWc83hAs-;JJPRu)!JXV#$|aW=62KGtVoW0zQi=6 z@k5vF-`{45@7A&MR?_=gJ|$f5m#jgSmVf`Q(U76FixztrG ze7b7vj_*3LyT&H&QTqKtM?CDwSnq+hAHK}kymtAUVEK1icQxXgr~j@xCoXUNZNoVI zQZL)-x5hj?A-VDP*i+|!9lpKSZTZjTo~MZj-*QWBUpf2;QaPhEWf2Y{D`#wvx z%{9t*Rdl_&&hGE(;d}RYwO8-!lGc7(J*qu<$hSGqdcU(3vZ#I8)Hcm!B@d3= zJyz1J7P!sfuD00ovG!r6@)HvGh-^*KiH+ISPfJ^Bm(NCR+op$N%90|B0z>O$CD+;4 zD%F1*^L+Zv)`zdTq*ppdp<7D}SFAj{bJsSBP3M*?XtkmIonwtfC>Ybk?~^4r_@ z6wlLID6!Jt)g?K}&+B;EcW&cT=~HJo!L>_Mudo8=yf=Mh3$&qAicW#+HXwE?4^G)t_j zMeaiF0LiWPwV-3g!=9;awKg9!_?p3x;cq1kwVZ}ATl~?Al2^-iD|#8N+cQV~#)Ri7 zA1~()xmOTW-XbG!xpL2Vr3#mQip}pc*M?cB4O83R5W4x~$KQ`HJdFC-N76udkCFJ8 zz^1mQaXs&B7^KD??pP$*snzyy<(x^a;$b0|=UqN0IwGeqZonp0WAPZ>BR_2xpLnq; za>tRd$enh7*FC*qy=b%4^tiMAlMIXx1e(6HuyT)WdA2NOMDFa*)dOGi{Fj_QmU}JX zWW-4)@sp_=Z%>GPmh3ab=T7*sLwNcq zC&rJid~htWdiDHa;=Kzet-ls||J<$QhzBhz=IuLLvtn`B@$f4%a%Zb;WD9F86vYbCj}pv_Gq@1E_j%l|eo{qye(pMU zqqq9c9ooNL4vJJ%4|j5`3bV+nd#1Iv$G*u?^171;Pwg49@=;m*_nT9-PECJheevDN zc|D~JSU|hJR;-<3vhTy~fkhK8gnb%z#?E2OgxW0`!F;VX52 z?l`)mm*3EFk9RNJ_srpo(?wZ}@Igg;wz*vsmtWJzxPR_d%?Z1U26&$CR#q}$Pk7+G z=tD7UBd&+6+f+MJ?~&5)^uhpJc~8k%o-eJ-NDYQ6G@tP5KHT@k_8A>6e|2-Y9MVx5 zb#_tgrvZ{P)}|iUzmQY;Q|w!o*Z##p;T;dJBy}upShxC{R`V;@o^1~w6>qH9zmljh zBz#2YAa8pyBk4w`(pNfzw9CeoH@&d#G4QibU1pS z>6HZKQS#|pQgyY%mwb%)94{^Ju=?A{w+e@&qno#EjmkUWD>W;=B{R_B!c(amvrD4h z`MgQpt7dThc8s-Mm3s7#tLE3uCR^Ekc~`j9ThcyScYA0_Sgy|2+Y`dHevmXyIMI99Tm1=-Shd1znpU$M1xLs7?1eSHh-m*?!NgE*Nv0|bz7ej zu5=r7lGE$^iOc`^a>?&Tp9%e(N^>R+lJlH>-%DJ6wtswY+nka&0Y>W{6b$wa2|2lP znd{!iZLf|Sn0H*Z+7(wgTxO-(V9j?kZ3i2;YrRic`1R7O;M@Jed1j@Hj{E;gE?X9O zTv9>YzU`RK<_~F)R(xsuzZm=Pc&_*Vj~_qhw4`KIX2{4YWJ^Yrk-f{V%#5V$Ss`SU zBq4;%NM^_=l~H8xtZcIP{@z~i&-eS^?~k9$xtz;+_de(4^?W=ZkNe|3uQ!jLw5>+! zXMrRk=bi|t)tT?;hV9~3Jd#7NMfJ?LQ~Ul!wNuP=(Y;%;IkuS<&7FPjXtl ziHSifmP&=VI4svaQ)Jj8j-@|+r}n#}U;E}T!$Cy3H}VMgVC4PN1NDzjbNpwZcJXPT zs)ov!$gi%E<(4l(ZE91J6C%f16MGxZmsO2hvUxovBaT^k6?wlYno(qUD<@gaarybR zvSiBf)a_=;r!tAv=R6u(zKCpmKBUs_M#S{6(4%de#Jt-wC!r_hnNmYm-RUIWjy-3+Jx6o$)3y_CD%Yn&_Y=&d_A&T1Yn(zLdiEhhpF!X`Dg4$NTaOitv%#zStL+ zzp_PrUXKq;ec8jj&F=fX@IGm!<8r>nlEuw&`U`U#Em^JCy!R-=S$)69mCyQS1Pkw! z-Y>WPpw{@=b@PXaY@z>Ej{yxXGso2azfO#DzkCwkpS#RMHy>AOm|*X{UfDQwot<^p zVJ`i_l6o7&!Wu~B-@ z;3HP0o*wHf!>?VpwUT^?t|SMRuW z=PtCa+!0->^Gz;NncmU30$Fc5v1RkI&kGV{%|(pqKsvBrJkn_QA^l>$UBSIb*t#i`+zy=7aPFyLj(pRwS0ZHocLRW^X$= zI8IAF(0^Sc+{1gsl2e{IQl->OIU#2UG>Tx@_g5U4FH~uq0m#tp= zn1mw9VopewDaryiCYJN(!7W7rHThp2{oOw9&4s2ihosno4Z=N)GoCT|jkN5(e&JQ2 zp|l)Ye^FMbyH3iS9pQLyP(&t|!b+HaJylm2Ewjyq!!WAcg|buB@AF43yxd;;u4%K} z-%`~enXOEDG2kw%?z(czlL&fK!`x~=+S_k-T_fsR=srpZh8Jwtox z9x5cO!MyZaw*8xX`)hK(Dyi-j>lt;v_u`hB_%Ua~Q?#;3Z1P~QUP>5sU&rA6u`ZasaMvPOMkBEiK$Dh!)8QiNxJP4>!@>u+e71-dtzm5gsr zO(ty47RL9kET~rG{#mWY@hF}uWnI#~^w-*>k#njwHX@IdDq5529e2qc+o4(^!D;L4 znHk&Nnfx?4iFZVa(n0HkJ`9X60^5Crbp}2cE?U!AR#XQFaFBPNc)KB+b3*I@vL?9o zE)hjq#8t$k+zc!%@Gg#V~wm!zt@U8$|^tRFIM|Mc%>uv-=||5 zg%JZsOKfr8c>7)7=)UWA;PJ1+GxiSosaBTmB{&6%Vz+YWJsOl!OHw`V+)7uoSg2Voe^)d&VzF6lVp) z9Y_^b670jstZlN^O`GFJe6k`0_*B|AxFRVk9T@6nUAZ&yEw9QyPBjhV3Mcq))$ALj3L!$tiUtguZ9@w5Uhxy!Mt!(MIkqOv&@?Xs6Hc z5K&=H`|zS=_Le$Z(!2HW+nT-O*?E%Nl&M8y9dFj&9w2X963bm`<4I&#&Yx7nS|i-zSCFN&V?djXmIP zb*4fQfi8T#n5V>U8Cy%5HGSh`2wc7_OcUoYFz9W%WgNog71QfhvflV=um0E4R)2Xr z#pQibctx5be>Gag=`3WZ{GNr{q&IQ?!}}$iBZ_x|Oowj1!9y&D?VZx=u&v1D0)gwV zof|Ha$`B>qbYZ+|e@I3lfmdBh>t87Nx)xtH(@d$7QZULc#9zq!Aok#osJNE#V$<_W zS(@Y-oy^pj?ajQSw|#gA>R(>kFB6zy`G9PSzK6z!8vC#Rmjzwt`6Gu0#(HzZB~2bE zJkRDkW#D5k?_+-`R4(dv1ABbmV3^#+Dn0rN+w`Tafv+mm@?B#=Valg z&QP){w`YO9PiJ~2CL|g2CiWbv%Dhjv%zP`}NxSf{$P5Q%X;W5eP$Eh9c6+2XSnp&C- zQ;?*gFMK<=cysawqfBibyBrVQ=g`ZKPZOOT|GZjF>>xjU{lfSJ9r`aW_jU~Pr{s%H zI0x1XkSn^2ZKv{NbCwAxFarlY zRGZ6dtY=N-TwF}%qW-iZ3{0Pl1zqgLkVt(Kmm@NPUs^5`S>f9 zn?cY0n#Q5hraY5vp=+Tod*9aTJ^0d3{WCZ>Idm*MVDf}c>~%4RqYdsCV-f0GPA!DZ z^`ys-o#vR;pa>^QM56jRYr~AVeTc+6x#pGO5~|e(pVg&TXLuPm1`ckw_O_nk=g(J> zod0q?GTihFZz{K~MC=`rnI@T*)-W>ud}po72w5WX{ok)lh~K?9MwcqOmA)m9_%)0gd^(tRYCAdT?fpxJY3UYb z+%NRiEc2ra=|aC~R+b5WTuD;Dz5Zu%(mJ(;I#rivisw#@jHQH7SErZl8r!yBeC~Ox z>SUKq=Qla;{I$c8Pig3wUni;9{O=`+q9VUqq*;8+@Zyf?H8+3Wi97EYh0A&rMfWFC2`ip1u7ReRrYnLzg+z3-0;L&>~iBTfuwh% zsheVxbJJ5+nS%KajL}4rDen4dHnWq~Y(&g!e;#%unmXKgKC7#B_F9Ae6Si&rkT)T_ z#IhH;w^G#8Su-fQ+;6Kr{5~4abuVkT_@967rZIG6T$W2}jxJ&#hndWIcYx__BNfMp z?^IExxG7UyU5QS^n`eb0IKjf2we3OY_(r%RG`)Ztuh7 z`w_dpfAy;^-B%djzPe_QpP|Hzrl&ab|u#(O9!82H5^b6@+0IN9}| zR#N)wfq4;Y6LirYe)aSs0$(S*)r;NUrN|8r*TsC+&CopUH*a;+lw10;PY+Uq6ep6h z6K&0`;kx8~k#wSy=Zc&ip%bv4J5z{s8&@Px8NCp*pP%Y35tChhERH0gW`e(#^9S_B);ClP9v-*1e=k`zC-MLF z0(|(e)>?KWq<+Tfwjv(~`MB}6+lWu`4&NQoeus#OUt|mjDTSj>x=pH(q7E4;tluFrl_X zqIp6GWUSugC}lj$4GYY1v<^&hXM4HZ-Nx6T|71aD{XL0uk z&9?OWV;^fA*4}AyA{|tpW!Szl;CL|B^alUA0}}7U-=^8!Qag>nqvfknu$~X3`%SS!@)H$N-!4D$(*5b+Kas4DRkXfj@sp!(e#ENNsWum%cVjro$5Rh!>Hh1t zrx`5fXgH%t%;LAavy8My&UnakBUhji?_u2cxU2q$4mnv`iJrXH$zvsL1&rEA`B^!# z-AQUP^h!>t_$J_4;i%ryp>AO@Gs{>mBf!}ZI!ENKGRGjE_QZ{xl!5==VHpmqYSaIv zUU8K_i&k4*bIuHqk1Lu;*^OVVUy@?>Gu$_Pc8Mq5htux&*NLOf=H5R7`*srtf86YR z;FrEZH=!(nJ2%BQLB6Z(dA*h>-O6ZmgV2VgMsUiG#EeQL>}3LHFx?M$LZe1x{Nrc$eQiK`dmw`*?FPBqs5)-p95`B3p4p zz@_@u9%8ZgdH$k2yj*~9YaW-%*Y$7iJA`ruEITARo;$_Uva^;3Y3D-Y?|w-nHaVhgbRL(bp%`MN98JU?cM zC>tx`zkHXFDeaRCkFLjnxHh7rQ=(o%yEQi|i_?-hhwDYAs*5+Ge9)!SH=x2{7?DsgK7&)&V7&G3)#q?!Exlduw}oG^-&w9^dvb>|(E11*zY{`p z{(^;rw`uFo+AK7~Be=bit z=%p|J$F<|DS4l%l{9@wwi99rPJ%NH~?uzr(n`n{I6md>-O5xQXt zvbDb6`&XZzD%x+a>$p=KVSkRH`mL7s{h`4C*TBP%-dyk6S`IQGm$B%W6+ixN=}`YQ z_F2kT>epf<(+lY{WrSUA*HSr`j9>BZ$r49)boS|c?zmQWOR=$E{5aP)e?cwam*;au zQ(T=~OFV!0a7G^s*cA3m%{XaqhRYfxHEaS|MjIEcQ~9u_QpXml}z3%x_;_u z>6?qWjrI+X9Bf$kx485W|@ZXHe*j=W1(GPgTS~#g9Mp z^)qFtjj0cv@oS1SH95NMRDza>BQZhf!n0M~9rarAmDlT_A@-*wi=h7WSIF;~t@?AP z-L!3`N+j#>*sE#WH}m6(6Lth;6jHRYOY6kfB$%(##!XI0aG%{s(;$6`D}N%OJ6GJ3z|mNW75w5SwO7R;}9YjI*o6 zCZ9zL@4R$m@l$xNys?tZxg;_{Lmlxp^Zl~>RHE8~Wy-$t8W+2m(z{xJ^F)YR%O+JS%m8D<9BV7B)ca29jD6>ZB$L$*9F)?{Yh zoTPci>XvrDUo%}0b(Wyw8C6qh;>gsRQ6={tr|r{miD_w#bJKUqNxjtNq&Oa(CgZbJ z=MvL;s%p8q<}{Q+rs}m<>OK}2PB(qWOfNl(hCoog*hP(=GK&3CmRwQCLYxZoOFW7FTLyYB zOnuMr;*QIxX=#lL>}~Hm6z6H0G-kiPkt~$Q$3Ay1y+G%51S0A99{AWFeAi{Cqv*nT-$a z+ZO7x#+$oxJvp$OVX>zPAJ}`rns3r!L($(*tc&Wz+N-%2=2COllO^l$fL{eBsa%x* zzE3f)+%J9isPg7ksoZwOMd1FmcBBV}NyszhMXO?5o!ck<~(*(&4Ob9=2y5ovyNw(24fR%}!M2 znX&v(tXjc$?{=j-|ILO__8Vq;A^Kv;a;{T%xt4yD#w;Bh&lCg(|g z9_r`Xp*pi%DYx?eoG&RbQQ}BH^G{1isoY^319BM#yY+U-~HDPaP@46L%>~isw12{4q(1=bB7si{0yaEzIpE zg=Qpr<+zCocXJybC}uP*R@fNouPQwsNncYFj+!IitZhHO=eth*&-tL;O(yK(~%&f z-Rhtmliedn!{tmrpj={|S)t)DT;==g$UlR;{?_K^P6N)>W?=)#r?zj0zs)QfvA#B~ z!WH_kUd;D9;SWX?XWt%U_~$&u)yyfy)*10CIH3nwlb^0;U!)9> z&(wB0(-}!Tll*ny?90`P$Em_Ng?zK0>;8EV5Toq8cZVMwtckfMN_6mjKJOy7u#!L} zP39t@#re-*BWVwzKeKpzmwAEj(F9m5IonRD$NsOo^M>aCUu~SPJk@M3ZoJ7q@9_Vt z*K?Tg`LE0{%Yxy%Ix`RmF^(bm?8IlIKT8{!%l^MPTv9P!AM(38Qii+4OpEs$C8!{I zgs~fdLRtLQJ$ErM@`!A+MrxPu{Q3De_tQuMzVZQsepQo|0_rU`-On4xW22*?f9?!1 zXvIg=ajNV@w|-xxhrS>*^Xklq{r}@-BKprV18kSVLLHuj`V13b8k;d=2b=~0Kpyex z?~Xi~AAogHKkaO7PwAaPXRVsViZ;6P(WhU?3Thh12|IfaJje$M@KrG9h;NU92XwD< zMXE+zL9b+_qEJUQ=G#0aD;pR>%l0&>G~0*ZDRKe387Grea|k6AKI8UQaqlnD6t-@Z@R$`xz)+n+yjMuWWw z#O(2;;4L_A2lAjR91k1(9csQavmLiP?yyB%#h~<#JJ6FFQp{05<@=R3JS#4po|u_| z=3eVUuRPiQP&vBkqI!#bjGJJlzrgSbCNUu%A|xPinF!>$3XIz0f5t&iLQ}H^${CJz zzO#56RYLMB+v+h{EhKgM@-*Qg_2 zL<|(DRvoVP>*OV2&EE%FO_;8DQ$zL&^R}1@UwMwnO$=+{Ws>w_^q;`yfJVU{00Eyc za)D8IHBsq$w6bF?4wE<-Io{aV=(!KDcn^qC;P_W&W)d@vOio6>dS!L<=EKe(ql*x0 zh!c0yq`$O}*gp@-F~BT9?y-yjX$8Ixn#C?LZBSYPva-7vE(H`1(Z&9s0)@C11zHxQj$&M90Kn zlL1C&W^PUo+*{S;#IO4Lb9{WwH8t|(Zy|n(SrJT1#l({DhhSbg2BcONFh~wwXgr$C zcMA1$Z^=437QFPrlmnHqtMpD9rujh$XdlpuR(|>X`65Oaj(wtF+vy+&uE%}F+V2^J zwlEnhLgL77E~ZONifwuI)-7mc@=BAr1_AVPcksb#Z)$XOH)a;VdZ?B6!q8omvHR^d zHLU6)CLGoW_Dd^cL=Tt%JdJCRVoE6{Jiy+lA-2P87HFWD(txg_4+MS)jGNxm(6s>E zp`(LR$|?9t5;!EKsi_ttO!mKv$pcI<+Ql5M@H1TX)$AcKzCPkx0*eDU6Y4H(pP(s+ z$&g6J(em0Z`kVRPAdUovhQiTzy^0}5P)C473b4DNcdZbN2>|b~HSrISV#i2lhPf#; zbAWbFRPO>85zMfgHclZtyT35sD@YhtSR()eI(J8Eh8KkG`|F49o*e~F$ z(*3)4yYkF~E#~C+`PP%$!{?!RYR<6Yi9?NaZF6id;;baFU~ANT>xz>5Y}VwNi1fFJ^ZkxQaPKobGA1uzJ}Nlcc# zs2K*T$nQ9FwS+Z}?0d9wUET%A-{az@J)^18QxDSAPp&HFE zo1%$Xxxgy~RXY$>!21ST1(wL?x+6xeb+h$_%zvElBfI(>)3_`)`@K9pG1msb3})n^ zUV~c__l$|hOkqyNruFsEEeyER9F@bJfXivl$n+n2(r0qX61d4=k%5v0IN<2PuQfI7 z!onRbEsGIql5QU~_~TT4W(qrT#z5u()%As93b+h-(V!53fCj+G+meO1hV1~mfa-jq z!tKfTKMo1TK>M?tAeJ{uF?C!JImER17Dxtc&~?V1kHC*y2>41KbcI0h@-~!&_ULXZ zR86(%b@RGyw}TfkBb71tef}Wq>6=VI?*e53*dl3xoDW-nt}q6 z?-zBSUfh%aQ{C?0+Zouy7}BHD81#kNrDSCEffy+!2gxu*1ZIqmRP1gaxIZ55?m!QL zb4f4HyfZ73aABa*Dc>XAzHjl0=En6?)Xv>52li8) zAX_^l6eURBCK$z^U)p^y`%Yh;^3QEuYiFg)D)$dJNbphp=5vuH-+xjzTz1vYuiM+F zZ2Gr=!1*wX=*lg3$#qk(lbJEdt2x(~ITGY=b3`oNNsUN8$0}F$gNm-AGAVkwL42MB|}KKEbKJxeUkSL%z76J#du z(mJOB^<>Z1?R{@_v0`Y8bgDTyz>p`c=T?kNz&FBJWTX#ZUMcCorzosDShq^yIFB|K zT(`zlv_Lh+dpG_DV@Mej&{Hd$7*&vG;W_8b@TnSiz0$<0ijw={Vt#QDW z5kSf~W-02vYz2~x0xK~QQSrcByLBN{Ctu@$0*ymJ+*LCAoWED{+Ij-)p!J4#y6?c( z`$2u==+P0a`pF_D00CS{Bch_@XZBTT70W`p4a1^X)A1uuu-6{@`K7l{Vfu3^AG;*D zqSSt%V*6Wwu_NS!G9^yx-p=3ZPRr&V${EMTnVZXLw!#$W*R)R)E1TdBlp8~nz4K4W zZ&(Z<3?miW_1O-#GhfryjoFw2Zkn$ezNMcY8Tc5$N8MrzxPqxY>@b4DjL*)r3H&Kr zz6(O);)ZMypi)V9g<#9Q^>QBH|77*-l>tEKgb4*4jD-!{flx%+fHnLRBQPePOsJ`s!61 zH^&93+1xDIG0|UDp5Nfj{tt~GvXCx4*@W* zltihYr^R-gfG4jH zzrh422xvqouz+9!kN6^>H1OoB#BF;VLEL6 z>4(krAacQ{s$!Js`Dm}EA%l<+dz%!}y=AjFb2wxA{mH-r6hZR}Cjp#x#5GcTzo&}V z##~-2uxgAhBV-14oAlf5 z8lwu&ZjRw;7?BrG6!#>5CtYR^SB@19eY(Z2{Hh5qn}&w8Q_9P~T1!f&!2yGJhjCV* zI527oEKhBHy{*G{%;aN_5k&9Vzr!@Z-wYkgEh-}Lr@@*6PY)-@04k91oR{+9tY)@P zAta^iqrh$XzYDV)oG;>h^ICvrvBE&yHrQLTy!5MuKLcyG{0y>k<{^jG{>{)9vuZk> ztxJ%lh`u_%w6F7+U7mqvvn!W);jIf?t9eeT1>JCvBpmqmg@nn*m}{W_tQynpXV# zgoFgmjB{|O5nMpbyihcUjktM}@cu)8X(_#Y86M6=84MrGFCjA{ zgA=+|AVb8JtchoSzqJ76dN5>+3=IjZ10DnAQUwsR0j5m6ZaBf$DGH%9egz ztjd=F4GXxhX9}$D<)FIq;o@LpOEvg-XK$m?$Il@oAz?sjZ$oYBah7HVAT1f&DQ%i5 znl#6ce;xQGCo4O-;$FddACyk)Nd4@W?=QQTOX3W}B#MdYzxptcp~b$sH5#GLceBSM zChWp))$NV{7~fo^cS;w#Hxd}_mM)cLoKuOoj1YdSK&D0I%wHJ1+p1y1scH$T$91?O zWfCEm4Yk10BiF+JUTZg1$1GCwxU-xaFSrNQJg$^#oS*j&~bK&QgEsFebDWQ}Q4JTj?W}u)r4ebHHt!ergL`2*#6ypLSpn5Q;U&{}Byy zetnzgbi<`JQgQt^c&Dmq!=k@{5y6FuL1)-HkyZvOIy$gWUdP7%?7WD$y$nP=?qA9q zN2`Y220uRx$%T#)U>0G+2v{J_8=~Kl{#1Bt$MrZMT{!F~@9_D{EC3~-nR4662q~gC zV3a{40db@F{sCV9V(}#r-!k6#b#<1WX4<+6`#O< z>b@hWoq&0dpIf*7z?Jo<^BXP(TBZUfG=!j(07Ww>7Q(1`@2qrI8MrIG?6_t082@d4 zzCfJeUt{?~L!Jaa(}O9a6OLkt2PP)0k_~jLUm-R@l0`8ExsiWN>k(8y5X{y7#+3|h zD7fS$Cf;m_5dx`EAnr@m#G~t))iK?f+1VqeL%B!8Sy)&Q2xZ7thv4cVR|9PUzST+Gok%btkz|mf2_Zf}S1+)#ju3?P06JO!9GEw!20d<| z2q2LSqAXYz1dsANvW~-{BH&^=PBbwA-2kQ#e>LIrZnD(Ew}KPRrc8g);m;Lh2xMfc zkTuDuu4-?Oj1h*s_m@VwNRm{4dKGdxecb=v%_=ulR6?UGdQa7;K*oKz7}s-KXD7tH z9YK=9-=ihH#7IuEzIKBsCTbb+U2V*3`dvg*rpD;cPdwaE$2Vedbt$qa-qynpc zft3bGf;A2}e$c+g-vMLQ&S_1O_C3faeKL=xzrP=I?J1gv&mK+RCx83)B%GKNbTue-ghCnFzEnA7*}|@!C3`lgxEg`Pw*nb z1!{|1b_VA-IA9;bf`fm)k@iJqDXVk(LKvO6>xall9q1|h|HmpO<_>~0c%xXl6p~lq z(<5xcO@}*vYbFRx=o+|J(<__MBx1lCW$PjViB})q6Bioe$Y@)G3{TVdYR_XSAP@u6 z0`$F|UG_%!@q@=saZm&iw6C*%Cz)myvZ!NZfAn4^jaQ(oJRn6cIJMqtP6CVbmD&Tx z-#_ZE(rU)WnWNO_-ymkg_W_dzIk&tk%*a^+`Fk%o$``qH94)BzyeTg300P3 z^IJ4j?sA<%w%2Q70$B;!!$bS`!$X5$7U4|eJ{_C(daHT+%9Z`5%Gb49N(|v~y?D_E zFpron=s^fO&WRs`^bVr=iFf`JhGX_FocFwNOtyD+kTFtuI)dYlTodr$hUM-MD|dhf z6!s#%N4e_)bLDr2^(?`ihmOiI!l(GEVf=wa-q7SgSt8mpzIl2DuqIKbsS{3Fr$p|& zh6PYm`~kmfgEWjmq7*L;ydR??ThOGUmD#|doWiqWZToRqaZPcX#E|>T@IieB)K~Ds zh(#J0+gap7zzealvQmIj(XU@9pQK7=E%TU+sXMPD5CbnpE`(Zo>s?~v5bjaXiUrGk zq3e~T%w~{pp<)x7@=~6Pa#lsaKxoe4m$mn~%VyQrfm4#WJc9T~^0Dau=t{01@ARc7 z1uMT6)!>Pm>-!Wnr)C5p*Bf6R0rA=eQJguxL2R)9T!w;f_PYii-($kB*cS@08 zgZ^u;eI+)}t%~TJH*X&F6#8pqD6l5QNL5TC&M!jp7I!?>Q>Z;mTStc=f{RT4moH4n zPI_N9D>s(HO$g5)$lM6jcVIV>a)r?&Y^5^S-1$y?Kn*0R;m4$mzo-GJ$4^=iz@5m1gK0{td?>Fxhd%1X3o}#GO%weL=R7pJMVy5)O!forWuV0OB(Q@OV#g$NIEo?>;%K`3|ifMxai zxeGvS@cHQ}DE^F&wr3lN?Cy34YG$0`yX9w`LpA0W0rVvDrG)tQ!r+E(DweCc={RB; zL4GLh4T9TcV{s&eRwxsTfz#!A{LI#ls?+BO{{s1*q6xH4HL_Q|y&gYaa3XcNtw@Rb z;&cZ$GqcQfy|2{CTyje+IY%O2yzqoy-F8e}N2l=CcSP|t3=EP8n89g`-^tBySv~|^ zXaorO*5KL^-RC|*eBjKZC86Ki0E~~!izHxuflq_2AR#3km!hVoRtO@f-koR#>lrun zdpIR{bDn!rN202yyYtyY3*|4YA=W4ms@LK;(Z;oi7*_LGh<>&ttib*I`5Fccqclt>3lw&qYOrsWyM{>=K?EjC>uChkk|HXw}*oY=U#!;vx?Fb`}i-sf)UZ# zFYz(qk&!D%i*iINJ`tui#hzLVaV3g7D%UwEmG`-$iA=^qHUH+<$4+zoJCGt%dU>wI za_--$hvCuDx7&*Eqj*csB8NQ$fzRYqS+%vb`s4}!RJ?^bKPkc1){rv^y#8Tsr;1hX2SmJE?jpgVi9B}kj@j+=cxRHBAZ-XFBq=Sj_lIE z6{O7O0d-|W(}dJTvV+}i!LLb4%r`B zP8Z1UNH^}1(`h@SUP5#>Ux>B=DO|Y2sO@pG72UmE285QP+N}^khpc_AXQCFN4T|gy z$}%DMzHNAt?uw_pi!|hcCZr)Kf_(!Bw8}tT0|ywk-Lw1?v>Bm;iPLq4uaq-+qN=L= zw)Z>9&(_;s;_sOj_br>IgcrZld?D`F764(g(8TRn|lV=J##n*BL3QEOel+dxn|;JJ{0(;D){$hy*J_5=kU5l z#>YQ)@Qe))3ph@wmmBZFRo8DB2z#X8>TuUdtPk$n31WX#nC2G;hlawRJv&VqOv)&E zuHNld2y!QoutppXcqv_9D=5uy)Z8WYY;0~hvR)4>6>So2s+w%8R#74GF+}4B&bP1c zK8c6*6GjkAN63bVrejgl^>V;5Al?A(G{*ZspE|x&ubXxKOVoGTPy&h#IUcBLBRWZ@ zcNqNqZ)WFqXXUQb?zUUDE*(8R`T-0yG=~2Ewe1<;Kxs@=v+pTPxEw=YD?VtFnJ7l{{e;?!UDh&%qfHzXJbu;zrTPy-Pb&>&|Y(;VSPu@Ctok zNGYREQv?}%F35rbk2h=A^+9_5B^s^ZMH%Oql-x_%D6wF&3!t=n2;%KcrV@C>xcL12{FeWAN9eAAO$f6Od$Emh0CY+?pP(MT zP&{K^`D*T>6+@r+*%QQKeGGrI1ef-K0X#Z9e4LtE6xp=@gGS=<--A;Og*~1qy6g)Q z5w+jH%OGD4$qN)uly?0<&AXSSBV!Z30Ac)`KHU6A9 z%_kdm^d_B=jTJ9Qi^GVP$9?Bemxcv6N-tJG$eWMe_+9L9uf!7>r8Ob`RC4b_dc`6Q z_U$!2!?etM4vsr__?{&N!3FiK?#H?#UA46~r`DHW{glxMzXLAStE*yy!`+q%h57kg zpx13=#=v=ilR-$Jz&{&de3-16G4Z&4Vp!1e@rQ^)|089*sD3bq*D6<=AvY__+>98B z+mrG!z=$>}j3Nil$T6{*a5p2e!Tm`rr4Xlpnn&~DIsugufgG%~1w-IGs@!PwnM(dj zZ8LM&ayi9QkOeSPcTFwzJ9XDR5S_qHmDBs_m-!5t)Em3UxcT@NGScETGqk~4Owp*9 z56#ri?%=orFS`~Ag(pw`%;_wPV2macY>DKm7%u&OD2PBu7bh(IAXicjRBNmj&cA`+7UM@h z9vRY8j1e8qr@mVv5`;M+q!kfM3R=`tx001qRK)9>63QdENL>cJkaf#L&X+s>F_sJU zPNaQPgNz-?igX>%f016dJVDQ$cQ+vh^uLwg_Zy+k5#0l=$thHhCpGO3i&3l_UFIx% zId>;)m}9O4B>@z-g%xknK%rGL<2O#%z5(r1sO=#=_+4WJ{8D50$ur0*Ai{`Q?R(qW zy&Pl($L`4C!{n4Uuf;jHWDpql`oyxW8PmTgi% zNUuY#1Gd@9io-k&#C&ccp~z#ZLouj=m?5`_myiaHvUV8rR(CS}aySZ<7ZZpY4V7M2 zze`VlQ@aa$ir@mNBGN-7w2t$G(n}6pkr!TuY+Op2z{L>Cl}2Eh!;X6MB88ZX6O8p- zE3+#)U;5#$eat}S)55(zvS=RJRL;jCeC`Yu2=h!$S)?5H5G$|Ett~E2Awz`}2h_RY zODZVTJ~?@IcG?5#5J=Gy64^2Qgc$ny^Ohnz1N2tVGD8x!1yPZYDEm{*?)9Y~iwF9t zXdg8%bNx!zcWz~>*4F}$R-Hrt_n;@$@#IO|(a2Vj5EHwD7)`KiLxXuCxr>YUk~gX2 z&#aBg6sQ?^xS-}BXmxOh09gGfnk1R+gV06q)$|qYU2AKu8V6vQk;Y!un^LJ_)J78r zipK~b2;^-fT{mnaaNrOr?fe^0)t7krP&M`5gTpeUSEJA0lu=eb>aSHG|0`a?=7s%N zfxZV8>JM+!eRZt2va2}^pE@-&^9SrcszCMo>TqpgJ=s-a(CTJN%iPN1L2X)ff|h~) zN%l#4uE;0CeF@r*1+`UP*41!uJ~>X_L-h&r6hiG(hy> z1N>z9tC1|HF*=QosgTti$l2?RZ*6T=qIY21aq3n(6tIBo$L3vB3$iC)Q4+{DLrNKG zr-IzvM4q>)j9KaF9|{XkhBe7g7z$^e!&^ql`yv8%z*%8Umfa=n%BAw6HezIYiq<@I z`p-#inP3@HVbhYwM}*e9H`;f$r`N z9QWfqPjv0^o~u0XFeI$m54GrPYiKwiLM|>YMk0h#O^snCH2-aV;-im+<+j* zk`Eaoyy`Ue2u%q1<&G9VF;Uom9bH|En~pb!`UNX>KCyYZH!`k&T&bjE);dn-q$H9_ z@n#TFKms^&vkOaI9mW^(OD zFL|5pfqjhl2X$HDRvjz60CXAg(_0q^!6!o?@45Metn!QhaTLc93cge4Ajn^!Sp^A0 zq^`sJlSpO<2O$u>^AGu}3?w68P3~MB_d&afmsfdQoDsCEqDz^}jZJiPM*sYg34K-d z@m!X`hrGO`dp{g|Tjmd7`33TyL5u_ER-n3KXSJ8I<}Ccvw~UZ}ur6pY*fK6RZnUh0 zhXfU!(ed#mma2mX4!F6xqBdpOoo7yJ?lPb4A=f(T{&_9`t=4tsaBA~K3F;m_c-*N6 z9{GnjzmAH!21^CqKb0IFn10CJqR1~Zf-)2c!pMwNrS~PQ^2_SWas2xc@y_=G3*vBt zvtsMMv5hT}qLQUac8h{nv`eh-ZjkW|>kotT{> z3N^{1>+g%4c?7E0GgT1Q;vfTsncOzBCDZ$%sOaU7x-GG@+#3%GO=I54D-+GWG^?_S zuRD>Vh|TCoY+qVlMuC{W(*5N*6}AYCO=)-cQjOU7Xn|{iVR8b)p^NGLiG}6L++2z& z&u0VSB`L?UnDNs%HM|Y{UG@Juyv{16<8xbEn*&~@3@cUPiYd5Q_`8tLe~$*OGju>DlG&rXa%h(*7r! z@fh~}I{Kn4S`LO3?He^^Wn-gVi9bd}G{3sqG3h9G`v>ZVC`jeyEqk}^Yu%|jdI;pu z1BftPmK_>^g35)_F^8dm0goFWh08OqQ1WI-F5Fa$=X}sQN#F91Hb53fRi}3Q^!dMQ zLmoo3q4QHb6hRE@7aAQEVkNV}?>VjOTldflD*H{k_s>PWjEd3%e&Ob7istQLr_iV)eg(OE?+9<%=gujtviM4y7P96c%`#SIhC>IV>dep_mI+52X;Cy& z=y#v5B;o8!mM2YKd?}|UAW-xel|TToA!LK&aL%)Z%jULkYxa@Ko7_@oxwjvtA_q#i zW7lSlZeCT8PSjv7o)l~H#>0GS=Cd!rZp$tyAr!z6cZ+y69ZXp$a2<1FE)9xkpVpGl1jN;yIv^2o=j;Ry$l@C@mtr%l#{+Gj6V zw{g(Ef8FHU3Ze83>+eYQ!t^H>trL0(;zobIDJ>PtBd<)W`&eSKd3`@f1B_AGi~JwLr2JyL+3r2^}K z%&Wew;82}e9y%f@hc&o(V8i?~R1YBy&du%Q;6Q^u-aSr#8R>n0=XyWp6w*$^RG?KN zY&hOy_QH>U(d#Yvs`(Y4sozsBY&E2h`)zW{E!>kNPar~Qn{>ossv6Yfiu{Ft2g!ZU zD+>D(q1d%F%|;P~9;B}hFlgf;i&sz4)8nMc23?WaEB*9I3NJJy7g=-KlTBDiu-K5K zu;vbZi>y{o`D)Dn$WK?4*Z>3Fr`xXjpCn+goju|#mql`Auoj;Dx zNB1i#I)0qKjz(+q^Mk(P(aLhe!TB#^w&sF1l}Kp#ueeIA!hVU-NuC`WcQv;^;%_jg11!W9J3)|@#cKrB65GK?1jy&KufnG1)+ zYaF0Bn#BHf;hOr7ZSOxp?CW#pAt**^a`{hYtD?oWvQkQ-zU2FWrmilsXi##8ji~W) zLa$Jt=AxRL>x66o4drZW2aIhqi~d0p+AHt4D1xm1+I6&y-e^nLI3Y2=S6W&M_VBJA zDR$Yr>}+&~z717|u=nKT!uI+FG>9N~r1P#Fy&JaNXMyS#`skO0kX31Ku~#~|+10~? zosDf#{sq3u$wkjX4kr?*#Y-X5g2z`c_zNCzF$gm_}f#Jg3`O# zeK5HFAD$+8kvONL51&RhgASB~C0d0isa> zjD3L}Y2SyREmCYGzYs3;Z}#{1Xx-KoswCeXW9Ync1UZ=2hu3DWA;m+ucU)$-(OHJ> zC|zys^^J|@)=Xp!Yp+qCpro`gGeeiDCH%$ZWg~Q4L4Y-EST#YC+cxtTMA5LxAU1(a z3j9N*IAsF19tj3ij*-9TuSVnnZM0R~dq^QkF*859Glmush;hJABostGBcX{B&jNx) zm8(}Lv0c$4mX>Czpb)5C;7$qD^Fys3LA2Aj(z`r*uBYI()?LES zJR)&OW8}0>qT-!?|1&z#V7xL4!yh{ek?=6SMDr0c%>?rbg;+u-BzkiqBD7%4(eeu~ za$UFqaV?{SyC4(O5d8JVhK6hm6dXQ`2IKUc91(ikLYOURE1}oj&CLzzp!4U|*&@*F zPpflMDWwJ0mp}h2l{ngUp-qVbH_W>>%?&eD(N6TdeDMPMJ`RZcc+2KxP|<=gC6d1A zen8_ka^!Z7C5mu&6cparJWx&R%`^WAr`Atk1S$N{(dIL6VX5%T@e9vULRJ*tqM&PE zT^WoDd=o!^e=oF&;uWW-r=!M)G|#Qvb+jN=b+jYT;Ep2j)>dP*0F;2zDk>z;pKoey z%}0aqzmK_p*49q+6uv_>17aQog+cs2h^xVJvEcsA0fHCt6|jKw0~0Bhj11IyB4 zB?z4iZWE?bpFQuXL*h7S_;QGQ@af2C2VJ>0R!`F^`Nh#l+;#p}-Ufq7Z zelj%$ggo`A8FJKX&ScvnDE>A-1BzMPUUa|TQo(Vgq8S;DRu|P}7 z2Sfm-^~?7lVYkp@lY@na7AKxg%do<2$wD$c>9Q9#1G+&8m1ktz)&uHhx^jK_GYDQ1 z3tB2XP-QYO%tgyc-~BSBMCgiO$Odm-HAe9ENNqUMo5(4^B4G8fbv>;Xkzdvf4T*?o zHLx!?FMW+DL|r?x-MSL$kC=kUz?Z|4AvF+#UEBBCzNerrUCJFcBjWk<_3Ptp8CnBA zd*!1#wIMLUc|%_V)+SroAwF1F7gTwScrKdz!ddRUha(VuLPa`JEDA|H7^8j;H!t|q zgqCo~z~D75EiGZ6U>w2D?*D1;%HyG4|GjfsbW$l~OPo?dNlZ$F5=mK(Hlj!)d$f?# zWaykqsmT<{8ihz`NRz@0k&!+7QYcIIEn{cqzQ^gG`+D8`$GyMbfA@O%N5(YY@AG`0 z=ks|l&s%BHL#VwTS|ao8^~KI&34*-`_^|SW4-P=i3ugh+-IGW;0iF(BEE%^1f;j}mLEEo1_y81oQERIX>@{d}z(IJ7 zK~Ii7M5X$4gQP&LfPEvcgOQlHwr|bcS|F>cg%_|D+iPT5=+f$W>=3x#RXf^%)ML6= z$~Ura$@lD_Mud}6e#9k*jPHZZPcOr3nKb9-y_JVx3A-5y)M z5*rmL0^>YRuw36BxhQ870MhIJ13(HFvcoHW5e+ct z8Hm8fY7Efi=(4){dY)S_FY64G_{d3Cd)4vbE+im<(tL{wJIN;44W_GOaQBkf?! zHwYQJxea31!uBHb`)gi45TuZY4-2VDm=$DKcH2LS44xQUAVEY^4gmyDouH_w;XRq^ zw0{z{8^VQQ=}A7S*u!H`M`R%J$LDpfj=y(auM+4m6MVIYM2zAvh6VXb;AVyyyOd5^ zlF4(p(~xfu!!D=2mk0FVibXHqe+$4Ra@VCOU(OZifFM+oH1UCT4t6WCT-gb}`%5n{Z3zY&7KEo^U z8>N5~C8&5)GcusHeb<#rgUkvHtM3TyGz9_NkU7JM*fe3Oow@B?5MZT4ek!k4-g|od z#0jVAk-A@wkSaOZo$J_M}Wm}4&dDrHc`KrH^)gF zU60ulrrgN+v(k1~$wZ>Ug9C^s(ye%w)fE6mHjw$)%21h_QVd9QGb1J@Z}#2nC#11J zP6I5UBP=!DhFs##kD~rRckUeeky7Ph+sDC=Xzz5NBDxf0^~9Eo`s-l| zjX3yBl~8>D<{_vQv4GUojYrjY*I8Iup)?|NrVeKzMzxGKPEOmFdHT_SadnDDXMXe% zu2#p>}me7(1^A2GoULt-6=iXdfIiC)a%cpC{ z`fT|ALZH_b1z9vIM#p2mX@`?V?%%l_6#yfni62H@<-S#->^JFAxohTm)EwZpgbmzH#(q1rrM=QyQIxU z^)wI+#bpRAy)#J-zX(q!WRE?s>!mIU7*r!$PqVM|$VNqT>%C>?AJ;$a)B|`zs*a1h zh{uF`uXJXaI#g4VeR`TwmG|biFAgWmwP!OWyf-!;U3e79}VT(7e5``C9T3XJe zg~k(D*wwKyNC^p5H8g~Vgk(Srz_2!TU$k%`a4=OC?sFH^*!SLQH-Bj=Trtk3CU40yyZ`t~`wv5olKIf5C%5%j z4F&C*{ZglNiPMT-ycG|*NA4OXo;&+nS#@H)c-TF%@!65i^qKYTek6X+8t2;jO92it z+(yTOk!r3EPH_wlpq_m1vS@=xU6Q4WFYqF|E9{X8OO~BRvif_Rm}% z+0nL|{dxDxmXYi6qTe59dX9bB_r)>(EufaUL>u0U_$gfBNIdi^!jv=wes-3G2((%v zq#2Q}L|3rc;L1e&_1qa_$Bm3A=cCY!qa%X?UWFxR8nqkN36_PElfunrJ=A!-R#4eG z&bJUP5QKoXaw8}cbr4qu&&Tvm4UOrPj25B01xn?^0InwhL#w>O?@sHi2cEfB-W=vm0kBD;#u8ApQyIJMxCvZ z<=hDM{sI$5&`q18o?Z4$d1A+wtFhhlf~qsw5+w^~^qb?Fve{*K0{`xjS!3|oTv@s? zYP$To>skR3F|kWhxAX*Xi zGa5&!x`uqejd<^ShRG?ZLY}#kooSQf13l)G$xxa&7*HRmR?n(R5KDo&P^nN>kp;?z zcMe4xV07gdy<^6oLDPWl^>Hkk4J}*!1`K{qgk#)=9=bg9;0~ z*A{oEbW63lk_#~$1$P05m3KEDc6nz_Bj%X*UY(>*2y#=T-|U-6n4Gjom`T~Y=Ubar z|2xxePrEOvTgioH-pOvvL&6(XBNF13zN%L?`&s(Lm?SVh|6z5E)c*brwR3)1&y{We z_HC~+h#Z4YFjMlAz@xbv@fxZR{vtVwB1_@-`WER-Cjp;?JPI}0Dq}OnV0IfInrC8w zCe}d0fJPPJzT!mHuw;;i6cscxus}jj5N{NQctu?TG@=p+^ZJjzpB`%U>6~JdhGeuO;%c8u5T%&|3j2pt+BSx- zBj$6J!qiUp^N_&WaJcVNsCH}iOh?pM&Z264^K$a&V35IO@2fdh-+k8cYL{yU zelq*+O?QQjKtTeN`slL|YsD>FI@C>KD=~Gdt-zkb~5UOBschAfjH_gv_oZO z_g>sVA$NJGrA$02G4Vv!<&@(|FgIdR>z_DnnN*8T)FS10eQhlSC4Y&oS+#0p*k>YI z-=1iyygkP}>y2D0`B#Q|<|T=ZqUMiI=0shJl=JiKRgRpReH!h&`Hlqf=9Pr-2uTSd z*8h*3s}3=>FD7Ou70B-6X{~;D<1@S7{Iyp8+|%B2U)-+NmEA;EQf6&DEb?`OF+q=)B^)}EuUtQ4OL9|NB$Ai z+!Hf>9Fp-14br;wG=fe-gu6Q9obSlJTP?ampIr?1x zpCKbPsjwhjp|o7|@Nk766W<$!ej~b~2ZJ)>{DlioCnwRk%~}8$td3vKaKzUIzb8+A z=g|AT*aO=;nasflZ^OwCts5HXDx|yZdM1ND_#R)nlNCx&Szw$H$$wFHVj&~-70FSj zuCa`2vtV-eR4RNGfxoe;Bj?}?E2hw7x5z-YYr2aB&$pCFsAm7i1 zV58D;H6;zok%8768+e)kq2R_n@-?^p=DzEdck-8}WzO{JtRYIYuDRI-YJW@4^`<*# zQjRN$9_nv;YjMKxrfKAg+Xuuorm%LLJ67~VFbL!MTvdHTyr7LuU&pvr@xj{15ppFy zdumtfYbO(wMo~EIA9~bM8mu`xUL*dFQJp_srL%(0%U&G*7^yM#Syl-mq_t8~lRYpQ ziO9e$fLKq2o|-MygUAxMt=3tMSo%DqMGO5^rGV|Z5h(r8lH&^y$;0SqQ{jz5Mc6S^ zeY~I)L#KoGQD<^!R-pD+unJ4vq~=n`=DeO$(*dz!880J=lw%>>aXWY0itKX6k?w-B z*=j)-@*M%6HEtE}i{`5Q+5PQ0)zO(JJPBNnsez=8_32KFy`s-^j;LrwaSFTmIU`el zS({4CE|r_Ud&azmcbrhJEH9mWv-wM)g}@NV0Fy8Ul9H2+aEyR}K&-^8^mMGp+EIpV1f!&Z z;Fke1iOa_-N$O2Z5GUgN3z*wJ?rZ&3u(@R{T3d^_uKn%Dx3()xq(dtDJ;#==xZOSY zwWC~|NSaHe&?#y8pPAzVeZ}IV0)3J?j?OJpa|37i7kPFYj-J$JOxUJWplF!hJr}NL zFDrT|o|!XJ`u_DRt1D(B`PXBwx`YpLh8yF0f5a45tub+99zWdIVV<+=wWq(so}(F? zxh{!CG6{j`<&^wf6Ve@ zV|S9|w|>;3YbKBKu`Z-7ko&MAG`AwXt8hwCrBXMcw%?NA&>gpn(%unJ{J~b4GokeO zk)L5p+7E^U=UrgC-XL=z(E3DP#UI{US^eM7vMh`X2IP)0N?9iB(dSSqjHa_f)O#t?iZ0HCXtBw4;X_UUrsw z`3^Zub-8=*z9xwaeVBOVuG`$Aa_O2!nV6F z3QH&s|H0ET`#9Ei)OBF>R_pgK{K7T$r}hhe7JilH(%!VAqb@i-uTqJ#48@nr=ldv6 zzD>IAC;U0BUsM0ZNr{?QrfEd0Zc$)5fA5;z%jR?b*;?m3T`4i&kMQyRgb4l(P4ui^ zySm>`L}AIlj41riF^SvocgI@u<2GX=Q3WabD0RrMljb|aVXs_%R*awD;-edQh(j({ zoo)%pd=E{Zk`^8=saM4Nt6`l(*!gQ}hPfP_JfFR1BlxHpyni>;H_>~kd;0JH0H}^D AKL7v# literal 0 HcmV?d00001 diff --git a/graphcast/docs/local_runtime_url.png b/graphcast/docs/local_runtime_url.png new file mode 100644 index 0000000000000000000000000000000000000000..5d5ffef7f8b83426af9d0eedf6e80eef9e011ba2 GIT binary patch literal 792222 zcmZ^Lbx>7N+b`WMC7lw|ARyh{Ez%(%(%sVCAky7PiGXy6bV*5fch_C#{l52)J9n-# zjK{;~u=k4R`PDL9MM)YRg$M--3JP6TMnVk=3PA)43Pv6Y0ld>5r0WL$f^!m+)j$G& zyphbppr9zAWF;NY-vB~(-=#%+0xEO{ci zT0ayOek*ENoVgUmFu${At1K)H-)I=dL0m6-6HxRXE285)$*A{rOVQX#tjl#mcL5B> z()Gm|FQ)lP>)(5YgP+&Ft#s^W16Z=D9Oi$L7}Sdt{+}1aFO*5K&Cvh9Z-XCrNEaj) z*>v^$?Chtv4 z{#Ghz#Ol{f9Y_=s?>`m%&z;O?yNfuiwlA7O-G@5K&bm10r}XtnP400$Jw1Idx?9sb zG{hadaAn{vFN?r3Y1fyQ3*YfTqEtzWJ(f>PN*Eng;!>oJm&qyHmv%kkI3XXZ9e2qL z2|=j0U5eP?wOklAd>Kl*(=IgYSz7-7HfXNcr>Tzh#F9b_*F5&&vNDTUEKIrQxv5#m zM{VK8ps^D5)~Q-JyOW=)u*YTTCgb3LhizDTXw^^zC+JO>66bij#>#HZ(oNt$C;Z?B z(w>+V2Z6X_vPvqE!?6`zG8x1mw6vpaKjoFyO%Y;!ESLf~F4yBBmKf|~*B333T{NN; z!TKrlmm5FjSyf??MFE8EM<24EJVy)0St={0Of)i!!dmpOJGT_3&!YB-3pdAwPX~v4 z&1EFKeHV+=TMs&34_zHsSg;#s1}ph)GsT#tK9grv1bXs5Csz$q?Caubi$@MmzsRCi zjw^pXk#)KrxG4KDl5X5SprBpKn!XT6A|6K@9kpg(C%;y(TG>eCw4DqYr;i?h+t4=+ zy__v?ou_SWYMlD%ZU|;0A_8Sp=;2EY-#-2Q{e9zF2=VRW%Ci(b?t&k|J%N;5xK$c` z1c%J98ZP`|y@S46J2G2SWxkc}>eh=f3O{W@lCz78!^peYxRFmDk7xb!)yA+*&y*hy z{~Y4=n#4P7^qmE>dixu9htzmpefB+zG=vK&=P-Xxo)uQ!OkxVE5M|tQTu2v`><5G*W`@aBc&);J4tDQ37;9rxkReIt%raQIpnTTr+&i~#@R(a9 zK{{HG3`gw3BaBA*{3{?$<wd$Z=pq)jdLiS9wO+P^i14K#z?nuV zHneUlOVW~fx>`iOH0SHo#yZ}$_b_7L`9swVgXIj07-F&VdcBv<5k}X)JCjKS%~!kB z8cA~S%#`+1WJ)PL^Xu!s9bb^{3LVPTRIG99KH~|5FTSneY36-RkbB~(NFO0$tv0a_ zC#SJm=fRsvrc~9iBKWHiHUf)~*|&Jf7n5h>$a6_H2g)#-6?k`MAXloN{yBo3Nu>IHAaeyl_#YN`1OL^RMz{2?U?adt@)Wy!2 z#r5HQg<*SEx@{A({;aTbXYYl8<8P9kOgA!pef=-f5H`a^V}iQDlN? z;+TR-l;R;+k-bp%B%u~r9wGf?USFz?U-df>a(+USPhsh0i!!Rx!VUiWSIhTtgM4aY zf}Q?zmy3Q{S{m<8hLfeOEtwhbjX>-{ReP4(vd4}or`|&SNNy}ZW=~WUT?y`N&vZFVv8+0$XMRgdXs45=gjx>_ubmx)~ zjp-^U3hxt{4Z)Q-tQRY{Y)8ncDBsz8%B~Hrq{dUhTxR;mk=z@9fx8r&p6j###I9E~ z+aY13#(D3Mlj1stZmG;~^gSG3K1y)4Ewb$e?4#gZi1);BjT31Et+gM!$~l%tIzMrQ zDjdEE=(z|!*hHr63MF)ZFME)3ZzQ}gH;Ld$$uUCVkP$vnuzoG)Nj4ns)hk8V%sT*Y zmrf&~Mdu)fbRmMCE3(oJf3LNTJ@Is5YWRsJz6K%O<4ta$OA-n+iaUJuitFB;OUyq$ zR8@w6D~d=Q!CW;5x$gtGr0)uE4K7(4Gs%AMka`}Etj{^F1bHKGQgV;b2=#Z6lv0|s z@wcV~cIl>Pt!&cQH-`n7MKl_%U(rwR&UisMW+sdp@FMO4J z9-Ss5K`3(cY-jMWI=&#L88*Lxc);P;9+<@23amaAcQZXnQ-wP`?}a_z4!0q< z;h6FFD07wavPAu@$O(G%8ov(}zu%+UOYQ|#i6#uDXD1|VjAJ5{J}eFj^PiBc?JvnO zy9-`PyV(O*yBg6d)lxqG0D~&E$58p32oj}Y{<-hQfkkg=l(@n{qa)G`!r;JEN|$zx zmh?2oE?yzs%w=s5+nP7dn^L@dvqpR+VDCj6-2UGAMHz#C=h!LRjBI#~F?^nqN*yf; z#qVe^pvNk4)1YqeMW2lSJ#rRy2BoPPJZ=6!`O(e7gQs?d2Myjx%koDh+HU5=5 z3{6c<@>C?e*nS55$nQnJrK`4;`M?gAE98AV?FMZFKr(Kbx7TA)dGwdp*9g|t#l@k^ z%gcU$yjm4fpJchyx4xUktUX?gVP=sR`luuM7g_!DxmohEfwHo+Ty5Nlp~eY^^ZJ!X z-*UH|4*5>quV3g3#iad|$#U%0f0_ECUcZu{Pol#A#%efve{~?N1uNs^#Obs%TG-UY zx0{F+aDDA++qCxvG+m(@1gz6ccnwYSp>NEOhZe#QDrgiwK0Z;zJTb471>#c-eN!P@ z8_FD3=6^uf1+vl2N*=DL>o3V0oA=|0E!pZO_V;am9K_cd{-Rzo}B*N0Gj&5kK(j-lvoqTA~al)XPg+WWYzU+nnuSr zqr-7})M=F9o-cu$$i`9FWha`4{%08>!(YRWL$~L#h|1KTdI`5>@b*CG^iL{DM!7d8 z(sjwG^!jP`qKEc3Vj@S?M=NPSfs_*7w(7*Esz`*qNR3@W{D;QwqX)w5vLbHJgO#CF z0x!41%xmmAgkFO9W|}T2X40iiD_k!QeEvK=_Yp!1*FG^VM$y-aX=s}u0ucvEE5wn z?Y&3IzF3L-&8O?fTL;l{RV3xVpnkpd@fb@4-SOutbKJh9p3|@(zTUhs3b3s`f^78 z(wV3%bXQY!jw3ccH#dOkcPoC}d?KVbD>>`J(t4?;tE+omE<#f%p^EFQFFkt|`wvM~ zF8F-R0I9XLH8U8M!11h?;8nQ9cOE;k%*;%}zUR6>ewSR_+>T>{H{(-N;^yYG2Y-KT z{F*g*JNv+J!Kzp~+PB^tnJVbT4L;$Y+lmhW0%T(sVG>jo2CZpjWo4~=?N(_mkLN>r zj)SzIMAhhQmgL8noUQjZM~`!&1vaesrb9NRhK7dGBDZN@)LO;Q(Y*S4ZkfR=2W*t! z+XKUwKi~;ttkQ5JC1Zr{EtU9>>N;}nW&JRAn_|k!$eZGeoS)S`|546UE%_g^bf;}L zrdGm;H6x;Ps7s;N;UwruHZa*PCpiIs_WsS9EYV-7lq|3FHX{pPRmKqBvf$%x>iU%$rz2Ex+xvM zi*6NnStuZ8CV5D|2(u_>wZ9 zXM(Z<9YZTU(w>q-@}}1p6M(>bzm)#@?cABSG>Ca(RH0Tuv-<_w_z1Cg552I8+Oj01 zV%93vAN4)8HwPrO{HVbO?oj5}mI%l7{TrJ7JiT#%0lR zP1)gh;YvJ6u~T76+`B$=^AAKAufxXBu=!-f@hMv;sh-TBi^s&v)nu{`tCHui5$;mZ z?#Y-tVnP(gVypAQginlX4w{4u@A(&BzRtjjk(MkaT%ZexjHx zCL{dg9}Al3Bs>ZZ(<4%P^_-1aW?h4I`l+=RCBi}`DI3~r5y$Ep16`w-@$oV|8l2p( z9yt+Og+^{iUz`(T{8}e@Wf7|j+PM{L4rJ2R9=QBf5bn+I`na4Qk#d(RpmSk0VQ<`t zRIb}4`y4VT)?vMK*E^D)gsvF8V1Z zHDP6Cgdx}kcz-_q{9Yt0DcP$;R#900I>g_YkpGvsoNSE_T4iC`=TlubH=eSx1cf)U zg=&^*ef2hrC1qtOC_j()GgiIoil(eMQqjuF)b#aRhsM6$*ZDwePH|0eMT_NjfJ+0$ z{!UM0VT6J=lHY1iE-pp_oU`FD%6}pND%`SDrJPSaQ>4P1K=Tg({!qma*ajyhV0T>z zd|504@cs+8JXJgeS?r5CDqFox;-HyK@f6hb%*>H%*7==xT(Pfo?BXQflndz;7q-}{ zhk*(#1)wH!g`lX-)hlXiYsaCKYmgghAC%}YO|Gmc3uX6njTn(0m-z&XQL<^x*DG3B z(8X+Eg-eu_l+2l$g>3K^OzjzwfNF>o+TDD~RL_JuR2#r03~C=gRJcSyL`1|*)E5R^ zl*^kNb$VQI5l!Nte*F)w?WRgoLWG!&alm9O1#V_cTa(t6rs|BAW?U> ztopBQ{e*+Njq@gUz3NA5nE(_T*tr0;mNA(M7v<xENScj>8==yz( zc&-Z4^%m)(8WVnvO}agA%h1EJ0|{Vlpo~L-THIo>-=X%W`S{ot3h<&qGsoBs?S*bn zPtR(m1D_wElk7`NOLO(nR%zf)myXwH!N^us#pAp=n4sIZMR|F7xKw`U9rF;%^%Ln>>ex;o(W|mDs-T>oKNIP;UadQI$=kB^>8=urTf z{O2;O-`NW0s6PN6b@}E!Nk3)!$mHxSq6ioaqUd7>Tlu4tlM`tWH%cT^2^Hkx<8O(ffqzw8Y|G6a|N^+#w_YuZ5YUjyV zTrA{XfWfjViA>7I*Yzs zNFFL4D;Yd^1tWfzqXeDfl!nymZZGhk4VLiW+?5`>H)5nBDmH(YNI25D^_PgMW{Fsi z@%@{XSl+z=tzwfH4QSb8b3n_s9WJu6on8nqbmeI8O0hg!0~QK87BS?4113sxgW#2# zrEge&>X5Qn8vSaPkvEA7%P+)@!e(4o_L}{)_my2&(poJ}Qr6QYz6}v|ATG!TW=An) zTqW}q7rpg^s0NZ*x&%8O7=dXNdp<$W-`_xa<9CL4CV~`MBqP7YNe+ zjX6DdIN9722O1Fvq#T$8XJUcT0k!%#{8KdCQzAide~B^nEST1RDqenDy8KPSZeh6_ zjWU@~kj*sd=LY=&{tLX@N3wiTN{%HF^1M76c!LKc*s!_|!ZMtfig2P=n+DZ`PrXw6 z{DjO-34nxq=A%U9?D)g%;dzwc>9D@vy$bjP~Wi=av!hp5On>NI~Ki1B)!+E5oKU$%I?k?k{Dgk9sMkX{EVRvV%Z_l`7-Q9$c* zzLUqdE_saaD*sm3T`fw(`;nq&!lM4}#p2u7dj<)c=Fr013SSBuP&_qLX?*=PFS$)C zl^^k5uA6Z01#7u{mj~Gri_G>b%3p-o%4~B7R5G<*W#wJ}Z)!-S71;0R3SRsJ?Ai%S z)>$5CJcA?HWu>dMPnp7>Lgb%mi9fr~%20@+GZ1!btjuwrrY#+qJau;xR$`|W=FiZ} zd{oo1Q@>3QOjAK3Em`}J(Di1K;HkHH2e+)FVphczrg^CMc$o#JAATNPDl)?3JNepH z5pAyv#$SBmc$~n=5-Vi9d{=&U>3gf!napIlXGGhgzL)XC@{R{g6>k6Bvh6+CLai+| zQ8<|6#d5S5gD@s*vKKm`s3@i@azpSUY%pxgX#}Lit+lY4I_^W|*Wgjp5Qslt7DO22 z-#>W|^;%A{VQOF-`_Ix2KGr~8t+%S@`5;ZuR~}4`kBgX^ZsfGA(~|TXFEzRFdo<6^ z&Q8tE$w*5her&KMz04QN9^>?VbO%Mv#>PgJFV>7h7GUQ)@2ui1&j>S)!GFJ&!R{84%={e98IgJv8wILM!^t>3SHg)J#BNBui8oXS}WUPxRz^NoNm zjes`_AOZ8odN$jCyxU{h%IfN`z#MFq7H63}-CqU%jX!A8(AGvgcpUS4=FFFl{j}B< zOp_=J8Kfy?VMXx#S65eo+uCx2MT+U59&ZJ`QZniZ!`{5`Fz~t3d_D~oj*}}4Sb$$$ z8ABwohe}RP=0|o|thdSk@x!L!pud(0xEp`gLdi$xYD`PZ%1pqw1TKa-zsyzphlGR} zDd^3DzR2NmZszIb^#SOP;vr~jEf=Ht;6Y^dmljD<#Tx-R2fT&gv*=4Yk|mg~nlA-i z?fc}ZRiPtlZob7bC0ICR1)4YS^_=PZ-JHOHr`?)$27w|QVU-3H+!u;WrF?pnvN8nK)il@9$^E%-7oxpSfq4#=H-~rk08uG`k(K^L^FQ{#Ndwq^&K>G*<8%ECw&sVm)L% z0n)6uI#(?So*OWCAj1uSwJ*nn58=-nU`?}5yK-7X z9+CYvL?mamUi4WLaqMF17|LBHf|!Jg(>aOW_nuBWPlMi<-Hqm2D)x~N(dypjN-zB9 z1nXQy4}5Ms(QEHc#R>lwa(0u}JrTH-d>-V#P8iFnWV*=_X;^{gipb6=CfS@m`SR(~ z*Q}BSl-Q1|sOxP=%{W0UDsSPIxn*bVIL)$a!xYf4g&i=<(xBe4P+3w49Z}$Ud^(ye93{layE)v)1 zhvQirU!0O60g^eZq@!|P4_p_{+>c2^Ok=E&6OSWTS-|$Frnn?5JBAh$2uF^znlAn9fmntpu>8(`*|p}Q{xGa&()icaquLvHiUs(0+L9B*co{#)1Y@@u zT8ya65GOBe=*r4VEYW>CZy%*1yQVPm%#DuG=-ZA}+(-HdD@~@Y=&|8d*Nj__CSSsW zr*`7K{)!bq;T5dUqC3d%-Cdi9CO`9*CZkMVs1MGk>q*GE&%*bjL&!M0;c5P{Sw_dL zC>bMuknAw$MA_3>a%InmEXvrT^AlgFsfRDFjuMWL7~LnqGOE9pDG~6L zm%8xg6cLn9?UE>df3VUD)tY%@(p~qOIF+~O0}TBT;v|t2RZ8z0kURtBUQQyNmol%= z`&j5Ba@H@5*E!Lqg96wPP-bpC!Lr%?5HJ1tnfDJ#k1_|)+&9OvlU9+Nj)!x}tFY3o z@``Z%PQJl9-N-FhA#UB3LXQ{|A<A+-+nm4qZx$ev z3}DHDAN8tb@M~v85e#uKEfq8kqeVNT8HT$C@M|Ky|JZ8g`G72;)K25r7=o|wjB3&TJQn!fkbf#J^OlrrXL=%nSh9%Qkk{M=ZDa`A8$84d{9&zHq_ zQ`yfw(LBpCetv!wA}-F()nDd05(a5vh3@Hhvi*d9{rUwdAAZ1^F8=i^dH$z`6L4%f zop)8xx&b|Hd%mi=$R~daBY(N=udAyYpPA|Xy^}QrT(AM_s@6VGxZU?li(x1$L9wNk z|7JfSNq`;M+j#Z&0H|hp%l>f(=Gha0DIzXErPJ0(=$rWKt=s$k*u=0}i z(XJJ5sFn>a5J02T)G_anDxV{3e-WPI?ueegq|`9h?OR2M`3~41TeU+h`M-as65F-JgNcLxdgAx;Fr3Ms(ogP};WYZ5 zE^PPrQT-w>nBzbsuEqM_>k)4oZO zf4W^;ssvRhSW=K&h)}H$clwB97c!mw(XZKB<8B1s+ivpnm*fJg)`v52<^5BmO0feL6w7+C7+BQc6m4qg2fZ<20$Hle- zcoH_>$J-obun82?xgDOaW-nF)F%f(ij*=tsnb*5eS<(JB9yICMcfx}i^lmt^1}F!V z4FO240kx#l0jN6y4~IYX{GP5L{=H1ki>d|_K2S-#clcL~f!_}S?!RV6gY4L>pCG@KBdtaMph^O^G z1Pw4{?&6X%P~Ox^9$TT(=H~3S-R>Sm+Zz%g&WydRc-i@{dm>sV2IBG;b1opGPb%wN z3TjZak0L1My|a4osSLvC9l=Ov-S>x4iM4}zi(H$Dl~v|{DJdb0p3-rUsPlu{KvkN! z4C^UMcAIR))KD_+Rt$+bmjczu&gJZ|^{oPB{`^$aBLR`eEb+&H1X!nn5a-AG zK>=Up@FbCN)B?nLP5ivSv2Rp^DO@~)ZV%(K*ZUBDOy1yN3OplDvIk>wspR=4MJ3<{ zN!=m9Fti$xNsyXlzfE`+_kVQ4=Xf&_a{S&Vd_1>txu|jUI+C{kQA3OawwAn0dKbv( zG9TCLJO-Xg^Mv(O(s9MQU2vyj>n9JqYj|%qF@A3fj(>(11CqwCxJ4Q&=u46HX{;Fw zY1IRGnQ#LT+=1s{+Vx?QJ|TVoLZS$_7g&J$Inz+M_pMxu&`Mk#Yw7vsr;^`Xa<^FM zR{QRrajNK5(bf$6;fnRCNXHS?DhkP&DMD^{aL9uCcfaRrB!y8tJc{^`*co98 zvC5gmWJ?VnUD;odMN}H0O|cPaBqilsybRr6WgsluGOmONJFR2L zLuN4+VM+}qpCJfe-Kx2XARfo}C=12^Zv~DbLtE!Rn5xwVZF!VSK3L|$Nnd&4zx>Bu z(gBCka=jVv6j6Ir7|i1qijs(|A8zEx4nTWD71(gQ{7l0AMQ2qMvn?!H;fqCtRd_-Qg%YhakUki0s{sXwN{ehP_Q^m>_{~8?t zy8(%xa0*&UdR!DW6KK)0JT?AcEdS~c9Q$TJKUB6}>$aUm+vQjhfhqtlM(c)yfPI3)WT8zA@MKyb8P|WShA^4(=&1 zua2hnz`cTC#LbGIFkQAcM_JkA{Jbo91E7|(jee(9e?%ON(BHS~k+7zkfUklE27C)y zy3C*=D@gL;sH&=pjx6>k81*dgb6JpT00wG?$9AfoLoXhnzYtO0jDtQxOj8q&F3X(? zh!uLicUzB7Cw{;^5Xo3t{@7Mqi!;SlT3QP1l*6PxhLj=pPs4=Ym_OnZD)I621CSz` z4RirX7ngS+P(=wo0Sd?^Ko194;x*?dV2A_?X=-Vyef%gdDG3jvR|01}ul#qPeL)Ko z#Z}=*H9MTEdE8U>GXb3#Vkd+71WFEpx9{cd#5B;ztwQC*9+BxXphO@h2N_U?!U29Y zN_zu3J2hD>b&%1;_K2S6ULoWQV6H>#v?73wI>00V_hwz*@P1Ng>!!gb9YixAnG7%& zp!XxGZA<2>7Aa-^sW$G$5<3OZn0_u62u~o7wQOa`QfLBR=T%m+oR+Y9OlE^2gJ4`F zw0w}Tdj8p176F8-&XN> zrL&ux(~8fHs=9ih%_8Vk7hssP?=E&WJRp%YGmccFVd)qWS&r1r^p(d?_I4n3;&8W> za#6_Y!|=U$vP}Re7c6bJCIAq?LX&z^J`Qt^ZI+c+QpL-lEl?8?5%C_?O<8{VvKdAl zlT%m-&p-?0rL^rFpQa>6hlg3YY}gex9biI$Ck|L&DP(ukC+KD)1dy}H3gM<#TGFZ& z2J!%4*6aa^o2xL5cmX7k*HKW_~|DU z*SmKUlapPiYu&22!4Q}S<`EzzX>uPf$Oe1$Mj;1?O|!G1NeN&gfNt6Ne18xwK?NWT zH30s=vKj*h%kA<_TmN{jxC2QjaE1iNAd)2n-1^6RK3tn~ysm88XKZ`iVpRvCkDZ-e zb&OOLVJ9>qi3pH-q)AX_@mj;?<;QG1ChF~2vpl%ZeC9;9CWJpa5IKc2?Fn< zw5S7;ct6Wfbiux#6NjCZJ+CA@xVLIt?++Shgg3zEP}U7=^2bDTpOh|6wY)-6Dwh9b zj?GF3*rvi$!z&tP;#1$ZRx*nYj+RXGiKtd zMS>7XGb0plRD(jtdZ&tdPu(|8!tg?5J#wOu zople63~A_Qe>Swd(Monr(ZW7D6fWv~#i3dc#wBdUqOU4Bl!ji+Q(LRWWk ziJAt(X;+(ebHE<3GO3@@N{^mh}1Zz9#!05i#y~ z8k}yl+%vGQ5YL?-c@LyLE9xnb|E}nr+kHgRG>TGT-`DV~Iu|L${od^?oVh405Ha>K z$Dc{2E7Y>5IGuv#j(~KbQsJkSDEZNfy!vB$AR8s$d`HIWtPr;9?r-7#ZS{N02RS0^ zd_(PCUKLTxfMF-eCcnh|4e9qy$D{Fn^Vd@ta#(6uL_hMiP%>;|yWXS&G}EyAsOy;TB2yjPIv;4x&apSJL|u2>oyD{ zl7mQ7prV5mAhi2dvbH{a3Zh`X~PVQvgBB5Mo;+5SN6oM}gbbz`6Pnh}FblE~k=B z2~01GQ4+OMT;N4{KP;{km6nDV6}_=-KH>&W+2ie*LMGq464j~c=^mgM1$woDu>?!v z@LK4b+T_yG$ZLy?&LHGhP&xow7H9kNMMpy4U-y} zA|Z?yG`DK|H8Dt5{LA7efR_}qg#@wKgpmNE0-rs;xCntMii(O*Xyif}5>(iiZUAR0 zq;p4uvlM{sdR)zDLj$HQ)KB7s2Bj?g#1%=v`ssXg(EG^R&=fdG8k(ByHVa%JRuOOp z3$ZrX(E?o#{=!?IKr*Kd?N7{NND$aDU8)W;4y}d2AOn^R8?bGF7Xis$prfOMUJ7aa zKOHutAYK*-wFdgQKd`SXW$7&t$epA}EX%X}hMVG=Z*-uBsB%E6f`}`I&*v!sV{n4c zg6I{%CIZ0)aL5MokAZmVy`!Topef*?;xpyc)Wk!y86Z(as?hg!{kp@pVZg|M7hE*8Xr1YXZ>(FYWRm#>Q!efVz~VLo`WRRcHLgbsMx=__LYXTs z&o^KXfWut8FOL_%4|N)^2TKy5S-ZBoH^6g+oOD@>;v8KCl>+m+=)WhJ=pvtppwDxE z24+$GOc?j)E9ImP``%%DvFo=m=Fb|Gy&y;UHL3A+_(HzQm7(go^lPNd(`U(m?~fq#B~GF7f{inR?>LepLNh(i48Sr!}YeJ67;q6jn$;F%GHql09LPee4z$kf1d*pI)94Bml?5D9Pjs50+6B<`m9*)|qiuqU=+;Auv5 z0&csLQBLEQ;aA0xFU1Us=gW8TXp~Fcm$AeOFPd0q;AgXIapS089IiT81BzSLfg+&6|dzQt#eK`g}#vqJ|059cFh)t9T+49%oCT+4NG2vJVU2idd6EOMtoQzz`VM&H5 z@=TQZ{|n1KKIGIS&37LU7PA5N0{iYz}~;Jw05Y%T0oEF7M^ zn8X{x9OOd`0#@oLpF9a{6ZOYm`8~mQbsQa5E_;u%-yTt~o24dOMQ9%DT%;nUzNGo< z(uQlk9ysGWCJArtp}q)`e6d4yXUEKaz{XsIRjiY0FH&0B@VFB+?lm{ zZ=WM23ylPe;I0UZEs>PBOL1WqU->eq>OFskUc!xq6IN|_PLlQo8!Em0%fX9w z8%cP9A+i23tX(t^^O~Q9w85-qfkcAjoK-hRrVL+FMJ-^LH?gbS)rdk8=^slplI0KkRfrW48YaYvFk>SUv;A9P5 zioWY>V3UH=Mj+eEpXGD?0Tg06z#KR~4X*lg&i}jxdEHPdF{=hc{VRhw%`fw-0mOaqi*sC4rtZk^#NW^qbPhAgeZ~4)!NUoQ9E8U~=}&_t zMwc`m&E983O<-SDYc~Ild2oO0T0LupT>3_*w}>#ocscu8_qn_2sC4IF^UcLlIyf5I z&%a#D1VRbG0xxW6n5EQqJ6hAo(Awjw(1^$F-3_V zx#swL?w56nsde(7b&$z%lG&4>a`<<~t(dsJZcxb7x6ejC@3g0%F1*&;Sk}>TZkNh1HgX$z>S>&v*7*FGsG1H$7(Vv zl3H2>fCP|XExGRmw6M&_uSbgwPA^Y4{3g~PKQ5*$NHZ)JC(TTOna=3<1!#*Z$Ou5J zFH22VL6S4mJCK#RzyBGi6+l?o0}J20KiNHPT9rN zgfA@UR<=KT26R$*cz8fgw&dsJd@ZkPO9LOXN`EH|u=As@FzD%#qLa7xFq&E(z zhntff0IPv2Gg`W@syS1Wbg;LV+ulyz4$>)PZcJ&{Co9Ufww0Wt-Q^u^F)Fyh;Qs6F zRz^X5>NUSB6VM=xM$+EJ=ud$WgSx%lh@EfudtpC%^a8f*>*=D+BS@KskW==Y25(@h*~5QWCSW zl(e+)0fFi266=d0MeP8_&yZW&1*zGGzMTt)gC7m+uFZE~B?=e#44O^;7XDr|sPDGa z3m}>Q&=$z2#)DIf;lSpClp#*HV?D@x;eh0BW#t=T;cO7w@V!W1IR!_h5>it7Bkoz&}Xg*;ty%c+dYSjA0nqLYO2ND(Xv19b1-}MGrxsl7q%NBgv`%gHf zIxw&_@3&fM))-Y}AAS&w;lg8wKHt2mTbfZ`7PPu%H@xF5K7z*zzLfVd)7zIk7x7LP z?a%H~aj&}_d{9~k{wAJZ08T#!M)nuA;IMMeuV!+ z81>DuCKgxe_Y>0T$OIDe)q0rEN$;Kqz0|cn zw7{eYq0;C1!oGg^Gkt_mo~wnwHck3h=gA*eWsLS`QqeUt^Dy6K9L!1wXT$aK$=99A^HTGRkr8h5S?ML`>dHcsO@47ttjM1uV*trs~sPA>87BR&t4%VP$);N#yRWqOcgyF;P)c z_`FPmoIjtpQ^-5mjo`R9fFWar<-saz?@#Pj8emZ?<*tcSzSzaT9GvjOs17BM--XZ~ zNEP-33tCQ6*})`9r_7UI9CW^_z-dj}hQ8Go?|C&kK7W*mfNLAxV;k7k-piiw#xeeh z$*Q_J5MH+?$}1oM53lFIdDsW0e1zurirCw4$V%7aRuW&@e~Ml(I)@#-llee(3?@#@8sPA%Xi*Bmqi3?n`~x342aQB|@SjrLPKI%}`HzVNT&&8LMkvZ_U}7 zu|Qm8u17mR%z@frz`otV1|^~5rlt-IM8pLAgX4{IR#g8+F35C66Z3$r`{C!O=vOg( zcFO_=zIWyz!q#}&g$i>2u!tDdwo81Wr*iZ9g&Z7epaKD@8e2=_n>jWT9LdQII0Fpm z%gSvZ4@p8f@W;WSL`n&3ISL}BpqZJOvvF}uaNWT0@S|Mc$iWrR%!hylE=7T5v^|^( zF^1g*$e4y4%Usm1nHY%UHQEmj4m=KP`ZG?C5)yF00W5H1=o|g9bhs$0Iy!*zNw?L* z$S{`h3Ja4rERKLSIWRivy3OFSz^6$WpO_d(BReW!Zf<^-dEy-k_$`DoLyl*4;A=t- zR2z6~B~!xX(8wyqK(5TP|E&iHuz@O8zPr1Fqz{3s1YjEjGjs2eD}V8? z&SgmYuXE#U_2Q1?pzPZ>e8_~&9`tUUp&=hW3&)aZhiX*6$00ljl{vpIbUOn2z-{P*PJ0Qx3Mp))9T9RH^)l= zC!C*lW5U)~f2(~hfDQZ^4j>ibO6Gzr<1g!3s-arbr|m1xyJNR@c`&vhPzW9!>5h68e`yB@31l2iflvk!5!wO1Xn8np zK3r;s1dJ{03MxbGFWc5`9wxX(teie)k{-03xrETjntl3I3;x%KQ4l8tR>^lEA3nSG zr~e$db6vE#eewlI>vbAzGflT{VV8QRk6Eo_@=)MJ4xu;F3npd)~7NfoC6 zG?GrB0T8NjuPFT%lE9(7aZDMK1V2j_8wxgCK<8?lmm)8T>+!q5lM9B2-@vsIx?P0@ zKVsj!go}qa0kkEEMZx$r$^`d6ee>*H&x;(i#O~@Jt$~)w&>J&6=Uop$rLlbWbPaO* zW%llGirYo9K`fV49`IVRIX9VbOPJE^)he83STNYX4u?aH>1Jf*;0T%r9*YPDQNvr` zkSd3;dj*Nth10mdDRWoSiWkILb9+VD%v?|*YL{}aZ;kf z;Y-e|Zk!y{C=CCCh%wr)I{N(sd~BlZG_`)a{leyq0e8>+%BD>2RF9|AgT-6@9Y@<<{sk*`#pcd{&>dPBjiqRC1~E)yKjZhz zfGlm(%T^+{0Rzg8sCgO=6f?iESX{WDisM%e^n$THD(AG*am=_MRQ+W*OtNSQTwVIV z(oSnXKZt+bj1Tk8plMqWkrMk6cQ!hSH0-!hsU{=55;aOMpznsx%Yyf7p8O zcrN=leAwPQJF>T|C@XuHO~_u^A$xDx*|K+vY#GUjtSGB&QDlUW8R0pv@AvzAe$O9| zKknCk-|6G>`CQlgJkH}hj^h*_eCxdDF>(VxrD@vrOp@UVmbb;x8LR0tWkDN!F;13O z7SWi@Qh$OJ=-#d9HFRz&^JI<_=Xgq8W9yu+yBV+1bk06FOZKoioK=4HT7L7>9?mEs z9UQVFEK`03ZJ~d@{amLNZ85JktzJ>?-xS}`j8v5#h|lC-q2*Yqsju$TLKzMY@+>8~ zl%*$|p`qLDkL}D<_@e^h98z{8`V;&|B`ptGa}%$1%~i^voK5#{1l&rp3y<^3iEdW{ zuu}YU==6m{nfYq@g~0{Mn?`$-FLK#=P7&T;`-{Bq$p(y&t1BH_;S=>TswO)Nss=hU zQqbAG|5$V&P?{f{Ipbc7Op(063-RFh3XBOSmXwsNJUoUeZKLB?Zs=7sUfvy$i*s%VW?Kvt z*J40i!!znLSVFEH^7^)F>MF@b1TJ68ygH2CGi{3_6eG#Y@$AsU23@^Y9YR1OC%ER;nJVbws`^^sIl z2Th-o{E5cPfDSWdloVD6l@|)XTwz>0@H%<$(QrpYTjiLWfeD z`tgfz#o--khd>W_!$#{ut}d_r(#4Z1do(3oo+R59=Rk#im3#$Z(%{c%;~7gMa|Ky<~3CLV{F3DI@QN z12k7qPT)XPV3u`q+9E+U)z{}(LU&VO>mGV83O_Yl(yyVBPdC8YEtp`&WBcGQ;;%qv zIImwPAWX(qO?wgoZ9=Ysr)o@x5gr5%Vr}xiAs>Gj#~9XGpdu4)C?LE>4UsSsINq5b z_`TTf2ZGX>&Y*Lg{i`cBfG*IAX@FBeL_`Eps)mQ75iv-OU{aZ&$UyUlsNLZczuG%4 z|L_6^&w#3wVxnQ53mm+jJ{u2DJX9l3(+aRbxI*`X=ocFrgeE2?0F|^ZJ{DH@;J(sv zlFA@tk1TriYRHJA+kH2`fmgzhOHol#qdVwV(6T+I8TVFsef_v!7jV+hZs$vtOYz&1 zibAXO-_yU->~!fQ(JfR@v+GNINDdZzqyfj~*%6}-eeSz!xR%CKtD&nD8=okHhOc=u zdfZldDl^p3Tm9*|%m*x70<@TjD^IFTD&KKD*3|QxZ!c-tj_bpOXTCF|b?OS6^u=A= zpv2&FV|nF${fiuCq|B#uRpV8q-G-Pu8;9>G*xs349{wV((>&(D@GsoV`*9f{rz7in zXVaf6d^=03JxGY!l=6ijr+E*?Z2I77^Rs0WDrT;j;MB5%HXL%gd=AUd2hP}}_a4Vy z8vO9QWDXysQt2bxf5AG0O=zRe?3;|opP!-mvU4`?{Rz%1 z$XXffi@?E9oGDQXEOPE49#V=gXjwUwx)CKOL2g3yJ+D<+I$BFa>u*G?i}JDl&bxGb zOCmZ8gZF=opND?4A(<~LWe^@Z&q7gqk8ql$c?`)HH?bm@RFfqwmMb5) z|KuFxqtT3t3wRg)l?Q4=s-*plHOuz2rK(DU%u5R_-bj?7EZM_e)*s5r?rqoxo#E{d zncTZY>&047tc%1y*`2VA+meK6D-+$r$b380Iv8K2^FE|fDWD6Z!R`1pMg4SG1~ujH z$GEIU%;Odo*#lG$s-JfYyB-+{u;*6 zi&tj4TD%e~pmQde-u3;8e#XNs?_TL?6ZYj&Pen@gz#lb+``iT4FC1WOR1d<n_1gt(01qa~`NM!Hq z`_8m^a3H)agd~AlL&NVmN3$b;f-h5lfwB$>%fs9Qdwk{2(7~A#)ie!#>RG+y;M+jNC^?V4x4T%lD#$(@`V%6U1O4 zh#I8Bh*Aq^DT%4*fRo=pftOnc#SfyDf{K%zpZ_UO%mb1AJcsEQNa=jIl{*~+=|9*g zYHz^r41_k)tAK7^7aBIm$$(}Lre1xWV4R;k)_ePjf^FjDrA#$l75P<2_f=(kwSept+XVj!nPe|v+|6K5DSAb60ttP7R4l{0$4f`v@xKkM4!h`| zI_HMyYHQuIl}HzV#Dw+tJU*t|xH`o>g~}%?d=@wVnc|Lf^pR|+T9doU8 zi#q&kMx5&@$rn?j9(YNkwJrHHL=Q72gIwXxmPO!a{*xR} zz2rKXIJvX0W)o}eDUu{cRg-^x$vb)*&fkx(JBn4(-*_zRDUX)IwM|F3U5RxM$Et7x zAiS24O}aA;LG%N`KvW!ha#iMo_Ta!t5d}GtOGer0M!`LSbQ1$Mg+kL=Ce@;rho1kq z?yp)eQZ9IoJ+VzUp}GHEkDf5w0F`=~OHW&f_3$?SMaD+l}V}WgQK`6tD2;~XtyR6L@4ksPA1cKeIVpl z+1b7d786FXR*Vdz3?RWu8KB1zupC!nVU(F>p4=cX2a$+RiReqp$|Aw3AOxH?S{MqT z7LA%dJ0`^Lr1T-$4)qNaUVcNb1-gVE-{WDP$(Z3BUn4SD! zD*5*>5FxBUmnSQkCA8*#}AaK z&x|pJPF*A9BcDohP{xa`v<97fbL9EBo$D5lLB0ZpDr5wikbS(X5pxZcO`*A zx&U5*LNyhR9Lv#yyGVIJ@)K-?VlAlyf%Uh~Q36RQ1o5P2W=_Cx4Mc|c0IJo&P70*V zzh4Gj&~yBFzVa3Hf-r0N2%3%|VDmZo`4JT{;>!UJ43Xfhea;_-e68=R_w#VzQv*;R zu-lOZQp5~^jZis6;axClu))Q|#6$$K3-G5bJ)ef&?glKaK!>#Gj#)M;r-qYiJKJVX zSkPr5a4MZck^1rN8H_o^eD@d;)LkgeP&<9^(Ia-Slm7co9OAZJd^xVdO7?jU1kntZ z?+vxILPyt2!2u4ckw_9yF1Yv8*nzYalsm+Ta2cG`EZO%P;3k2FERAW}q^|bj)t0%X z<@e;Pir*Z)C$AvfWW_y_1WZ>?Rub=si>V15n%z1LkagrQ<89KBvv=etYG#q7^fVn1 zR)^mRqLL6+$dLDT&lC6m5->2h2z+y@hCyP9f*Y3nfn-Aa5KrR14;&RG0Ek6KIi(|1v-4El}dX@Ii47X#P^|;=0-k4D8=yCWt zE2fxn&#%TIBNSuCPWV$$sW0gq=6}Z>?LB&_=-G#$C1~ zOy(i&VdlQz{8g+4^F8H!l($S5g@KWu515}a>NvTrz6d6G7Db*+S0(5WN-k1LnMovv zrI{Fi%6{;&FT>#}b^E|P*KcMqtwi_me~AJzNB3+@S2#pRHPUtiZpxl_kSgTZp7>ZK zs_?ZZ`bVh#sf%?=2v?aayNRQKW~0jP+l)QSQw)R#?asBM2#<#f3d!CDS+8ABwHnw` zv3cK`C6HXY8IFk(WOJnUy!wces|6L{+ROoY%HAihl;CEn#sN83q1YY4yvbu70+TeG zRbkwerjrkQe^Yq}TYHr4Q?O@=>#MQy`4wY!Y4vPb!pK86ZqSdiQkWDm+$Oz1dq+k` zdP>>uM)fKthOu+BI}V@qnd#r}A6u&X19eTGUr;WN_HFp_Mys^c5^ID+ll%rQXR>JyG z-3(is%yY5=8rsb)-!VM$9}{JK>p$h3GR%Tgid?t=|9He)5#tQ2wq32Z58oc(;P zR1C-H$SOmovm_##%$wfXLhn5m?EQrHFR;edm19jTL(3RSZ}xZ7CeyT?*whio`g(=G zrj@=N;!BW|kYLkxO`a(w-t1Zpv#G*uz!%X}`ji^$tTBhWrK;sYf!4VAlIeVRU%c0h zN`vC-7#RWSz~?|Uqs(#=)ZJIy4{O3_9ma!QpYJteOcW8VF8tuQ2;_GRO9&~>9qbp4 z#Ofc%gbw3G^lO1R#Rk@{7_DPL0;_z6%K}Bvoj5qd7vP0)4Zl{5hq7wB?!cm05;O6b;BXV#9?FfqN(Z2 zKDGflroKLX?TNjx8W7LmzlRzL*x7e2+YYcMboWp&m(IVP08wNIFAs)#3Fzk|NmyP7 z>M-AnTb7y{a_SWjAn2{HrO~kHw-U55^4^w&3mzI8f?*l3x1NLNa2P z>=BVt!bkyL9Wl5gnnm&`IMax3NQ#CtUWm)4vI3O=yt^UqT5=I08hdY%3=&aq@7st5 z3s5p*Fatv(G0-TsMv#DX(2Y)v2nV5*AZ85EcC8}n`vK`oL74t()cx7}tHz%8{6%aD zCsO2&wx-cV52`A;8+bi7#+X2N0Wv_O&;jcK-ZYhm4{5E>r#YvQPi{}tq))=U8 zFE7p!CE-(I2I#-UfG+xPCI&rOGEtq!x(4DML|HF8g4a+o?7?M1vLiarf6GsM-FBJ# zU=F54V1@9ze|@d`3v!VAP+mYl!a_$R;q7cr4s{nyjakrvOoHPWn49aMg~RMFeJ1s> zaa3w}qWoSyZA)*7)0s%;5Lsx`& zK;hj>x|r1rInzbjU5IOF1=0V|@Ng*%k$^oCR$pnF)L>IHHiKt{<8mcNnfW4Z5J>^lBE08$-g6FlA(wy4Dk=^p?(x?H-^J0gN!4|V3F6SN zRR)b#eRv!=9p&XwK!L5phvo(Z7fexVNeHm^&o35W?on1=z7Ns`7zgsB?EM+$Zv;;8 z@<9434KR}cG>wocxRgIeS3U+th!Reu-Gl)?*j?bSBR3hiHdmgX^rWMR|0_T}z7T*n zvDqR*eonXNF)JAjvvAG}RT!(qm581sei0ejI{GuN%z#91Yl!_M=WAOfgBBf_L%46B z!Jx9B4&~4xh}MKH{XRPwNKrHjQ63r_!()yNGAA`85<%|@9i+PCje*@02^v}G^pWA( z_O|o7xhUNld_+HH?Ftq3_1E`+LW{ap@QODX3^72=aq{x6LVBtlU_v@cUy^47^5F7D z9t*xKrlx9`kWJLF}qj{Kf-%ukdhgVq2;s zrE85#Jq;M+iQ>gRaprrBhTbjbx0k$ZD8r#6&IDiz2F=yFV+V8x)88)I7pxprp>HRC;QAEit>nBW_mz zwfJgt4$Jr43b72D2emS;x1?{KuNdyg%Pc$JDQ=V;%(iI8VcsIMx8rF|PiG*t2pxAU zVW?NRaa(Zirb(2LgI!zwkMJ4q#b&KPYNr)%_X;2L;Q2R3jR$W3iX5GknhiaS4?MMN z>&8C25kDBs`rG4cDTM<1vBMeDH6C68e_pG3&%lbeV|X?~T|X;M-cZYhf+oW?nqDA( z#wmH`)=191(2ekAy*rMvHz(S-Z31tZC}eIf)3dHUFIpRXDTe;~ovnfX;EnUu)*QC) zxuvsy^r?3>hdS6{7Pdx!FQ(3XATE7(*os>H_U3PHeVl^tN*Y%1_ov>0e{Rnh1jO3$LaTD6?$cG znp`bLZpMzdp6zQ)$WXHS7=IwnX&wCx}p9e=uBFFC~{}P|V)ou^&fe z>+MSFr)fmrZ#r1e)x}I-KlKoKSK(-LmHTXPK>We|B8rcZl>Q(mF}xs|P(B5XuW zYiAS6^mWhR0clInZojJV_2xGZD)Unnou9_Gs1>bo+L-N@y4upM#K;_nwyUwZ6ki|0 zA7W4YAnOssc@DrBNdEwKkWln1(Rg^+a%h?bpOPA8+gHP0i|+cLZ_zn z2OfuS07&Q%-)54i8!xCcz%UFfzUM8<4`2JCl!7YVh6d5&mU$))Ntm=p=l|(pDB#f1(~Df%9S|T` zc--Q_EliFl2p0fKk4)QfNJ#>-DaZl9&Brt>1fwc$Y3RM6lj>dBcLSBb1jb=~a>F<> z!Up9>SxpU5;phO=MVdR0>0UU1om3MhbWu@J9Cgf>?$C%rCgk|soEkKaJ`3*HVq#+3 z3vJiBmPWxM(FREjFzWgVrjRwb<`+OpK>g=y-xdN4@B_yYAzP7n38>Gdm6dTwyaLR1 z#=d+Jj%*@`k^}yj^w|KoJ7C~L)GJ``>I!L9Kubbl_>6m;FX02AOCB_Fh?dpwM=XaZ zw3TpS8kAwIs=~s;qRh*&MKEXt0Qc;bZUvZTk!G^FnUV$@j=;NY zzGdNG(Zi6Wejn#nFp@&iZG+DbST_UwwlORv+zM)398V2`HsG&|P2*$@_w_{K@Ba;i2r&>Zbdnp*!cPPhM=Q8R;Pv31 z_!fAe4)qVoq$(@JP}Bw$s+SpLK|v%K3XrrOUz^0X6X2qz&;r zBCZeMOIkr(f#fg&IBWbY9uF}h$+@}1;E+IRe`_<4L=T^@pu+?!;%N@PJSc1g{SnX% zh#3KartsT%eJK9&#uWr5fNGHU35*Q@MJ46oeG2X>gnWtjMj=QV7JP>W_cYD6O3%g;{?P_u%*{{2bM5C4Flb;x;b+MYSg-<1cnM}^uxGX3rUG0f=ojEjdy zUYi`WI6$z=X_McdTtzl9+`LJsJ_(#II^NW6EyCE6Mx-enjJ z;)lXC6&-wNpb>^N8-QMJa}E_q;)Dl0^8g~-AfYBNp`09^R_ino2;M{PikmG9J)qOP z2VD+Ge2|xVSt(wY3I=DmwUnWC0Zs^^{{i>aK@7{n<2!KgB4j)>2yMVT9WH&bsp;Ec z2KZJHIDU9A5Un<#I+zk~FH~`86sHx^?7lUDPj{EgM2Gyi8&o^zhC9n9{g+8z4kVMmnfu)SWRPl zwX?1}qd58Y0HshQiiiRG!7od;7i1Qyo_8tBUA+i~cmRyVyuL>!Ihh#pHk!wvUu~wW zD~HGWNcQU7J@x7RG7=RzW`V^-pORtRs$=!(ZQb2oDIRUtiTzvwrJkMyL+T5b`Nty0 z&5OgK)p`xHAIN%c4DgFnQgWikCg2CQ4pb&FbZloi8`Dw9-IK%^^_+jkhVflQ6&Zqh zMq^*yZ@Kp2X1mN#Q>@n#9pQ<>S4wT(7}x1df5EtM>wU%W{wUlU$&99`8(#uvVRn}V z;u#-%sWX(>#TH4~-RpuQEbJ+BtTiT>)rRaVPx!;c_!LH(Q6}PS(>-w;O~X1 znkLO3JS0nx#`YB=?Ojscveh@YQs zn;L!fNs6@a^)1|5b&`+AKWgevT4$ZE4L(dsm?ch0R87U+);(wDR>@uIlb*!nYagz0 z%zno}Oz)u@y(*RVxuWy~G12gFU6R!PhU-Yml>7sQu@<4E_WO!FX}J08eRsYls>~eA zI~EYgG3(?qzWqtPLK1f9B4kxYvQPj`uo7hA8xbMSQVRyyEUh z>K@l;<>e0#JFjzG$Pxdp$IIT78ezG;S+n+$lSXLhg1z3j7UV){XEeW2uzNE*(;omKrTUndS|vP z=%ET+5cnJ`E4d$cU3jmTdAqyU-%(dxK`shOl@Op*(b|d-_}Y4U3Q)VD zoc;Op1aK6@|3HNQH`3xr66C;#)B%svJxhd`fL|N}4e3DIzl;=lm`)-gMa$5%_mN2C zLcs%dJOqK3!s~*WP)A!^*1#Yujc0!eHsb&YMA~m8{tVohMsGd2knKPSID!B*_&{Jv zu=h9c-PPqe)JF##)q=Bknz<3*pH&+*MgzbRvh2YK{*zXyWp4h=AqUJjfC4{!>j*%S zTC?x}c+v*7Y5{`2;B6VN4@G9$sjN_o2w)l@HSl5!naA+I+e3r50NJJhUe^q0fZf1f ziYWe}1c_kc&&T?vX&Nbm1eOWBu1{kRyheT%an0J4hqyGiZAu2|USd zYI=Izkf$0JAI0-`K8F$v!BCf8 zqZPbqpaRncC(YmW!ckDw0)_|mD&&i7?Ch|S9drU?3gQYizEQBzYoe8+&<(8@leDjC>L z-T~Kv6jb1nazRWSshY+<4j>gcD1eM0cLgW(N}hL?ZFWb%3G$pE>2omBR|fQqJR1Pi z;bDW_EC3D>l=k{HFWjn-JO<&Z!>_0j*R15$WvT+|N=#N+gPMHy4 zMK~|Crhv{A;14ix4z$o9(20A-jEsyhO<(I3cz)Z37Bs?**GLMpvYrnAWClx3$i;LB zWPAhr3O6-kiKY{E3kAxq6F{}@g9p);)POUA_kxEEap(dO23OSHo{KoSAe&8eIx8W8 z;G6FXVenoAg8%b#Grp65rrqWL34= z{njCQI3^E}2}!hthXZtXme6A%1~V8RmO9Xa>H*Zd@F4X9U#SX2C4|nv+73v3K@jn=wc(db| zb~!-24e(Xl#3XQmuKvpfchrilOyd8+(7mA{ zJx(-HZS2FaL_?=w;4FGH%;KV|fq}-FWwS0z?L9vxGbD*6;cA)oFLUwHJxb3mYWe5O zg*$Q*LxQ#g_D?@^upIKoGkiKquE*8<`|={0JXwqS5)(LZ?`2_#nx{%WZ6yb31v9~!Lu4XPgY1=g;}oIkTI zGGeR|yot-trLUpc5D%IS9k;b$a5CBvoOEGLvL#C%+u@)F4X+pm^$_JTu&mHj++br? zk8u^*48En*bbE6qAjs0NxuUBqtlT=J?mGEs-SN=v*w{zcYE#ktyHp2ONFy=7y`ug7 z3pK6$k}?%+D~o{_f9SFC-M_kQG9;CAmN zrnfO0W;*-)nab{)jeoirw|H@GJ}Lb+gW^tyYG(~svDz1z>p1&nW=84PI4 znUxhOShdS~b|mc)Tz`Zc$+ZzI6-^N_7G>7NO!LY6dAIY^Z$2(nqJjb^)W|8LiJ5m zDOUgcjk=h?=Vj~pqU^nBbLxBPm1X(ss2KJ1SRJl?g4R}%L4{><4U?;bwY@gpe{ny% z7ylf#|BJ3!$2e;-e&L^>;?i>?{#w)1%)qgB-Xt1|GX6hB`dqCapWQZLWorIbWapQ+ zaeL9BDvy8sz>%Dk0%Ps%*hjrDBk3ggDPP4qhOHG^*mvB8u`7rBxnIuW;e=(73_2`l zi!wGV)$@$rS4!%zSx@@#%TsbiZ&f&U@C9psigk=Mo*`*`&3dWfBIV`eg1sUh`P|N} zbIcTdRsU~^#k&k7#spe{%@k(b!ynXezR2Oe8GBs#l`ENy5o`N>lcWUy;I>HYZ@Rh} z+9{%yBGr=5J*IqG7eY0yOtiD3XXk^V7;@$L&{(4B; zg3k8epRBG`W$80gV00^B9|jW!=$4kjD8Tj~BBQ3p``;T~IIiFvL^P>jRgs6|HRpvT zH5Tm5`vIBGD^L^|m>hvSfyM}62xM8K>FfIe`4P5~rUxnpPCVkfha}OKbx_hcSZSQY zSfSbL71;BY1ho<{FZWQ8lnF%h_4H+oQYo^Q3kkA-S-A+n8PIe=qi6-ESkjJCoE8@^ zuV;8&7QeZMDBTU_P$ZDS>(>`14t+0E4sd-!VFy7pct^#`A40)K4msd73{2dk4&K3O z3+Ehil}ne$E%cg4U!KF19|Fq&;Xo~Tip=N%k-$Pr2-N~5GP3swh96rA!_zczB>=1b z9@U1ZLuM!vc?1>cvo9#75buhzCmvP=lIH|X=q*qufPU8RY_kGNkq;ce?}gP7UJK~h&@k|_>F|F ztRhLD(Dx600Y+;DRwAu|Kc#o;NPId`?;C`W$V?xRPN(VB3}z8pK__ z%d~_xvxCnb@9sVeI{ut~D@|2Mru)LLVCcZ+^LEFy@u=9=dCeP0jDs1I2dE~Vz*tGh z=l235nhcIC;9>~uV9dV%Css2s91tSLC^tDClA47Ck}ZQ&grFqKs-7BRNY<*%FSrqs z2$E~e8SkIJd4RMBe~^a2P3im?9Hw3nj1z4Qao_-7u|$L+G-d#uShAbmXMwcL5FLpQ zL)Pmdodj6XY>`mapiB^3A}Y$@zr*HG@~@e|?w30YZs-tH11!em!~~f_3q`yl@+2T> zcM#@ZEqlcOFGs6fhwJ6$Ej_gXfc*&U8Vx?ag_woFn>7NY8f?o%4C?>siXc@NusBB0 z?+Rfg0yTDE<$XYntch1RDjfz0$*j zy)~EMMgVWoJ~DHgHVY`;5{IV+jhm`!=>f!@*3={+XcyAiYmW&?nubCO_^0NZKOxK+ zxc9{X;sUnu-@62>(J9O@N2*)CeT0L72FC}an#eLu0O5iIbiEM1?%?cP2BZ=uJ|%Jk zL9;|xwn8QcNx-epC%_^Zgk2uV6T8KU)&=T)< zk}XJz^5g%h?1nv7VkUJMM|9A@u=v^8SonR>JdV~1w^fYib=0_SJ_40Vs{px3vnrDO z7c8l0P95=olM>$y?axgPw3H~2>$$`nAG&1Wdt37Om6&K*pwH}69Gcp+dC za-vAuni;Q~{L($FB=$0@m`4i5gleHgLO{lHE1iWh{bjfgQ@F*l2c53}8>x>R740&>BY(6YibAC&d z)BVs|*MI)~keDoOvD#=K*3Fud4>&qSgin~;i~2F-=Hr4o6B{J2K63pKkag9H)fgY^ ze+g@0YTgLWn{0by(-c4XM1$LdB~s2z;1*y?r%ZIyL|jI-=qrzc&*Ye+ejeY`9--?T zE51YQbNNg=q`L39Ssxj4TBW`n9H4{OcXgvLFd=Q`1; z?HY=j*4UXx1(Wz2;WTw~g-g_popmeE78# zk^LOR#wv!m+s)+!>RP~*Tr`-cjlKZYJJ>HXI(%02RIuw z{chJhJ<-GP{)YQ2jNms3KeLwLI~7?D{*FTf`Dz2~NH(o1a^`CyUz)myd_ATf`=rEe zw|`B1X5MyC@sh(Tx)+xD$*c4JwE zO0cLVbvyk_9!14*ldqN?mVD8MR^a;=J6--V^TDKq8(BTQ6%a9mWMUvn<@@aNE-pJo zKgNR3@?h@?JpLdQEC%iXhB}Db5OS-a&4pA&I1XSv3exg{Nt0Jaxe2n8ju(B*pFtRz7LTBfa3aZ^@}~u3^nj!A_3DdWErv{1Le#5 z0c_a~CenZvOFKK|;I{$JTL?H}2%SJ|_TZ6vecns2WrPsA0q!MO?#C^`42jSz_8grJWl=tcM$OK3F#Q$SP$d#N%G+6b;Eo0Od9*=mLTGbnB z?e6;Hh?>Vd)d0{MlR963D z;^m5A+Y4u0cnPO$v5UFMpTZSUB1p|dOtzbo6;PBwhJ)Zjkgr%;&u&46 z7rT0n0>r+8?$jE@t+GsqR`oNQ=}%1Wh?f7JPTZew0zC-1A`SWYkCVi$x5%#qB~~P6 z2K?hcnBN5cA1y!&Z0Lbj)_ct4_+$lyFt}>fADSF!EY%0oY%G&)_^}$}g|1L^*pQQ^K zO29@F7rIvO3#sE;9&)^kNWH5I4+yZz3p$=fAXXoR!5$J74_j(dG(R#W>y&KyGh(<_ zt9Oe2if4qUt!&Z(z!m8il>Ku}HL;D0A*u=}_V;0BPNprAk-7&D-WZH7xij#jd=Ay` zX2X)Z|3=~fG|LfjGKBDW_!QuADcKT-OwyzF(C$vSsQ~6H!i%(-BmrW$tAhE^)b#FP zCxmo?D=`}=7RXseBK}}KC)~u0WVVg7vgZYbYU9tE=V0Mb&WoKdM{RQd`H~_yq>?FH z$GA^V&(u)$OU&dfkaFa%O=`dpZQkfN+j~sE)rd(GK07>kT0LD(Qc&ugTsrJSmSvv$ zb0a*?yDeh^P2V@YXRyTh;{}VJ_W})`nKXeX+5@?C7v73yTGiY$)5g1lQloQo9^8p( z^2>pQ!sOb`cEV>Gnd`r%`lp_mR4xmY@0##Sm@5kte_XlktoKD_AQtanv~UHz_UQa^ z>+a^0VGm{D%b&s9JdT5wQ{Q6F0!1Z$W{T<8lbpz$C8Tl2^BC`7FKIMe5)!};t!I+7p=*8Yl|yNllQH)E85TeXJ11;)Qf??!a@a7-30CIr?lgrU>m;K`jU zW?uXln{{q3M0o;&KjauG8mrqSr}ZUU;Fol!#zNRd2s&Yd2)Tujs|!Sa1sT4g5y}$} zq&p&aSF81k-nrHrEvc?t?R486b-1!cl3d_ToSmL95Orx{iIc$EzJm36Ok)j_lO6ex zz<-?4oufy)(y5p8l;7}CG0vW)Nnb_Na!8BAUZ9@LeQHOI!;WXt{Qvw`Ph86VY>XaQ z6kfGt6ztBgQwiLNab42#7<13-uK9jeyxY+XOuqcwtby>R@k_s#xApxO3v){au`fE0 z6;-c-gF=?>5Na=NogF`A5LuIWbn#bc!7He*)W_-W(V}(qmd8?a=lPSRu9J1u^7_Ro zKfNKTC2i?1IeYtCo`?C+Sv^?bIX7FGpT470KjHsto_(p=#UU*uU%IfF{j+`!@#U?5 zf2UhUFT%eYm+sEZ&IKfDE}hVslN{)xT!Pn~1I$=!NOKMG(Jv}0?4B}bU>%1oO{Fxy z+I@q*{<@3uDJ@3`5h0yHS7CFr0ntL&#Ti3JfBS@8e)Xr);%)Tk#jW#WHUVdb$p3yN z+_>l;twJGJ|ZDHzzM$9CR6d6EIw`bL*;u3@uHRU2eoOeXGTq@cATW z8`d!TY;5C1!)QqAyR>G^vVU=`#Yps)E@@N`KIPx8hh~-SowZ*ZthHTz&Bx;Vn2{(? zQ1qV^O}D>&+X{ydPeE$bO`=1!O((1Zwch~^3s-^!)-!VNFCD%kO+ECFPk}quhQtjR z;P`G7KZPYHEbo2rT!xo^z5TYkdXn_5g!6|E;ofFz{=!$+gw1wfba^*0|Bb_|jXLK0v-Y{`*w?)tQ!)RY2`^m*zKx(YeGd z`O3z*wGvhgB->YYM{mYYa=z#FXUu%N@H4Kx{+OTE(JaZ6-lbN3S+P1Fgo#LBfdrF* zkO0)A&{kabK;eR>d;Z0_%cstrM4g$by98&W#bE>@3{L|-PfeXS*Cc-5o}Zi9RWXnL z??>#+5Ft&E!$?$;+Vy?W^LIy9pyAxUrJWr^MgDB($wYI08^dx(BmedRF{Vy8x-Cdb1wIEq-bUmLpB{}16Am{RLP{z!&_J?5 zp%E%gWtP|1hxBEr3?Mfq;1-Z1hs_pWH82@n+lV zT-wcly+}sBPN0zgXC%g#B|j-ldGp~1nBYMu{*8rDUqVxT;mMnSF)uJ4X=rW2`) z?jn2F)TD}&e2|K}=^ai!ag&;pFj!LyGd@@zI5Trvtor@XQc-OPD&*;qX0RM=Ie|Hh zY)96Ae<%?E)s2mAP)^a(Hj5oS_;2^8GT4P)j6(K*z8ch3xFFURU5P)HllA}mWz&R< z`JdPM|M>?oy=BkZ|NZ|bBwTDYTWCd4n9_`&7%EVbf20rRbD!e$+7Hy|{Sd=HulB5w}RPRlY`- z-6$cAn2s9o63T$bYqoxf5`n$x;x-SfO@*9VY`QTS1rC$9Z>^8%otdojlFDDnCz(1w zL_ZfT@x(RglG82~+}`0m4LUbFESHX!*s_cM-_`qMU2@g2;+T>w4R=#lFsDl~5Mitq z-qa2pxM8yVm5VBo{tp+U(k+oD^aRf_wbps zpb|#DAEeU>?n>}Je;SZl|HLBLUUjOCc-yr z#P^NF<#?{MqGG;&b9=Vp-nX+kT<67X1f+?>zTO=6W|v((B0?@UxtJnqmsfiJngzz3 z#yoN0T^@bn37G+P46_4$O=QM3#8^dq|vYIz&m!?L&TnBf+ z-MjLnzf*+^`VxTmE28A_eLm_KunxhHH<~0IngI9L^8b617Y=i?e z!*+_8U%JgeMb0$F&S}4*|0S~f_Y?Q2w)Ah!q_OWALd($OSOmPa$8#_o<=)?Ho7o}r zXp8V%-P>`Ot^_$RYJoY%OU^bkLdt8;j%eT4DGUc!*a+suIZ0;o9+Hs zUDWof8)Akf;{H|P4{8~O!}8th=S@ah=oCtNMA%t1)K-6A2} zCFoNVYn)ch_KR?oip9R)?mQAxY=e^2X?jiWc#at;1{NoxGp_eue@fyYONz6UA!2TK zd})pO@ZtNRw}!i|GuP=x1*ON7s5qCaE9&tcW`sAfc~^W0?Cs|o^5J+pf@=LbE|qdH z^~2o80~3Y`tUgL=UA@D4G}3FWMzI~TRvjyheGk?O@YQUFbDoHsZYMGIJ4z(?{h6E5 zm9N6;(}H*W7%R48>iyzeFgaT2!}6l9{=3)L9I+jUy9_!xR$0fn#6<@BNdZnqaFbs- zr2TEjox^U_GpIK#`%mxqt9xI77S>&`l6?mf&=x=?5^=o%>$t);>TW@2I1B$O23BWC z)k2nmQSdylpf3HVG>*;12EEi;>q^oJc->%d{2Pf`-oZ9Q0qquqK_U`LUdn*3PqUUL z)8KYT7IJtm1^J*LQ&eP^l@FC5^?o&}TA=e&Tys?Q$$SmL=-{3tZ!uOkZ2W*6)(19FH}_6>t|5IaF0 zFV0DY;L7-5eUWC@?N^l_VWk_#@6Pkbz*l_pm{;mS91QTtiLNHg+dDa}l*IFuz#_~q zg)HSZFy#0LnNUzn|DzBAV?3C-@BSLa`)?VU^`l`84V6)lQB{enoll>Gm@zc9t?f$< zvrVx`r=$eZMfqB4OlFK3nyjO&!yd!trCC~<>qei2P zog|D3$?@bJWJzSoOXA~htkWY2GNAmA9 z@2*}@l$7TkHLHN@dY|Gg?vKgH>=wo!Nb@pyKN=#3xG% zgpa@W-^pC2l(Y++zM|l)X*iD+ooC1CC_3fG zFWG+HYk+p0F-R%>%X{r6K{QedtyDuN6QO%*HaEDKQjGJi_;9#h<}k3;7lutO$0|=` zv>FX|*^eKD_`NRBrcB`#=y&>3K%?xc#aZpPVe)+7vCmnpF@96J+*fyPt46{@f7G`e z;-WjL|AgX4N)=y!(y%2}K=V_fr(yFl#E2U^P)YfwtYoJ?B&lHMR7M<2mir>$ykPC9 ztz)*lZsf)Kd$b29L^X-Tx5y&7RYG`S;}ELcgqK0sg3iGORgx8&N222HEzHQ@JBszD zo`E%U`H0R5Z@}H!5cn15SRf@Ouv+Rs3DXQd#O+=t&GMNrC z|8;u0Sqg_uFZ}s+9O2F&$OQ)#vt0#+4DYZNv=IL}Sit2()!qIi`TetmY^aKmdk51$ z(CEmc{ufNGE+rIqDRPJZJGgDyvX&CtaLdEuj2Ra#Z2a}L?8-2f4qTriWtEXQl|r?1v7UmQ#cnLA{b{g^aV*=VIlPo-Ez zs>0>wQFJ>@%1pNEUMSWqM|@1v-Uoa|Eibu*%CwS~4q?LYKdF6j#H`k^N=Rh>ftOo5 zB!eHJnMTFb@Kos!3)UG%#y9NytYm_;H+XS{__ZlAa`3HKiG`uCjdi=6Xoy;?YSwKA z7~v|IX?Z!P|C+SgsE%9b^^&cfOcvu-vUP6sqogLJeN%YYuR4zDU8|s?MN4srh53d% zEC%<9W(MP0e&@84{;$t5X`NNdBK|Fe&)%pOec@FiMreAu6Wb#fV>1K_B)3Zg- z!cQ&3rBur=(^N~^tG?yD!F-IHp041eC&2EoVNHJfiu-Ar$5P1C%g-_B533oaj5*%6 z$9Xc=YexRl89ulACv`TWh=F(2!1>N94BfN)`Bj#jYtRY(&s|C#HQ;U$SeGpwB89jh zN2SiUU#_v7Oa@a&hrMuUt!dX?wL01N$T@d-8%aXp(bScDsQ&biYIDIN#YRjIb1`HTlvreNMA4_P+M3ySFvm zUqv;sO@wW*=2x~p;egFS&XJspjtb?qqBqV$9S(Lpg*^xbxk?avdL`K;X zk)4sU3n@EWvRC*$U+4Ax+4?~4Qa30Uvs1+RE<4&nUAItN zJZj9=s??oA1Vb(ovGyvIAgw9QV#Wk?r{rd)efiG z47)<9#eE7zlI*G;9~8e;=)Aql@KY7%*{p&CA+@zS>++fNA}J}Ut;K)x@HYHCk{tP4 z@k)3(ZY{fA%>4@Yd8hUY?OT!Xm>ZjOf3E4I(-ikb>3SL`Ek58b!hQAJdCZi%0()15 zdQ`vkiVizv_vc-O*KdXy=;iL|mtU{D|3Ol2Ly*MEIkV$5RXRHl4jS2MgQtTKza!vi z;5gr@eXqrD$Y$%ASd`ZDMZIgqeUb6vUzjV(g5OQ#ti(RjlK<7hNjrwS7V7^Zh3!4D)wII!w8D7A{BpDT$Y8ZZ=-;f8^Vk zjm7TN-&q)4CmK9~5g0fA@P}X9*fuumr82)h6C+zq!IQg5)|Z5r4h^vDzsh76>w3Aq z7-h(u>tv-lu^HNu#wL&)*oPBL9Ucmu z&z>q6&bFUJyMMTZb#z|Di)bflw^M(e$#oMK9#JDXzv|rj#^`r}QC{abiJ{|dh2_9V5 zPxQI3SqJJn>OtopuWW6&@Bxh1+uLz{3)qFAsD+F(->=QGjZ>M5y|z^{W{Zbsc>H=i zO&~P|79~OFJz<;P)8i6*n667n)dzlbNPqsC$lKo-)_kB_4U=Ul zJUg}UyG-~xLnlkd*9V?4YFu2&s-rbvg;gR1f$;gU`{(}7>j{d$ zHw7ty34IHpi3@u_-O14(0SwjZB1lHH77{Bw4#-d+BGq-=+1(u+AHP`5 z6|cM+XFj?|A1Tjj*i+-A$4MhVirXfDGhL?5$(+wn4XLaUt_O+CVa3J6!^0HM6vIul zN>_Hrn?RnkIlYPUtixTy1k)tgLI`ZS?m!{7@>t8es?Y~8_FHrfn8`!Ow5!Q zMBvCU#SAybuPW6z!A3Cj00wS%2j+u!y}Y~5FUZQ<4Gwmy;YBOB(v_CyVli#Udy`CGGH>pV zhN5U@B*+*e(-`_Mn+Yf(p_6OSRWZ0UW4G1(n%tY&*e`eb?&<%=J%IBug))~>wY?67 zjtHoy$!@HcTa4>>e~0#J%3Q|Z|M=D(8?0chg&^_~oc2Zfj)s(NfrwXyO`woZ%FcSW zf$zdI;3D5x^?^h`z$cqCIosPWdNq-^O$YtIVULsDzB7sJeM-qE2gbI$qFtJ23^*}9 zo@H3>6Z!j5VHt5*iT-CAQYAIQ+R3}3W9(U;ffqD%ZpqrR)C zGV<#YvGbjX4Q*q`fgR~!!C5N%Yav~iE^LxBAMP9GaYcR=i0JL7`0#RXiRxkaQ@Up~ zKgkL`xec}~<@fqua&|XaTa={aX$f)1>Kt4(U7q6Xt`l!qLNcR$l>nQO@02z(Qbv^V4*{qp^(i(V6w9N-&c&MHlBYgg_-)(gs$KVKA|DJ zDn3kC;_Kc&rX#PJ?LR)ki9h#jA6wk#=}aW|#Um2t%DHdz%KjoD>H1w#b_s#4o4O4< zYqKvfAL?U7zQ4)o-wDSV>r{={2x;YE|8V+N+t-%il>xiB#;fX|IJB1GirU4^U#0ij z@HYdmQ_kiLI}GitNDOpe=&+_OQNTB7Cd=wdeT@;4DX=vqWyj9H4xjWn*L%i~g_@VN z%7@-du2>y?{LK0>sK6~&IPN5SP~U$MN(##&e$>^ z$w&UXdh7Qc`SG6=%|Z5W$dunE^JSFv&2Z$81ZFsipFX!!cS;x~eb%lXniUpy<=!h^ zJO_v5#S`uy-n2#uNhxA9X65Ghjj*mrsXgP8f2af5t)~y{Jw>1QI7M?^7+~|PR^)Ge zyZ=F9I$K0FnYpGu`0i$Ic3w|P1y4a{RJ7iJF7Z4^S)=CUuze+N^tt3N>0DEA!sSioL@{F`V158|$8DJdxdOR|7XPy(E2ij9t$8M${% z49c2YVi!5dIqyou-*F4L_6t%6C)n3t1h~LuTJ%<#!2>Z)JOjF}j1!Q~VlLQGN+{eEmX3q{0 zU(FgDCMDcVQelN}>bv%>%&@{N&I@!6a<2n67EB?94$|(e;S+&~3ludDBDpi~R16G& z+MywyIS;~MyPf54FMO7ZT0B9nG#KY}f#~Ea3>{S9riSN(EwLL*0EjV2ivb&vH@C~6 zbAW3Rq$#pP!tkEwzsv+(=qFc(7_oUekko}Zqyv#A38rwxGp2Z1arN+q&-Z2bfs_a- zhruQx0o@TYg9DAy(P0-}S{t~3(3=C%udm;{d5WfXFz?=5xxvN$#6kev$Y8O6UzeO= zYYNtSKklH{QekKfea`Xy?S+9($z4I12a<#K1vrya$S95^BY=|(eM}HYJ`PF^+#@Jc zpcCUIhb;C_70J4l4nrJZEQbUSTkuSX#3hXjctPZPJ}(Jn{z&|7 z#?2dAidsr04wKc6wFkKm=phlU9t{vH$L=@nogMK+LLeZZfo2l?-c)&!8K593(&wg- z0KpB?1VQbhnm-C&Cy+e7HLSoHuLY$iI-&+{B{@IhX9v3Onj{;W}Jjm$pzIJ%+rg-h#y(Z3)x+Jw@h-0k=0 zsP*jE4k(5maohkS!r1Vh-?-MrP>?mdZwOQ@&=@u&BM&?!Z*G;Kj5gdt`$>2{VKafk z23J3Rd3V?!k+5R(u$!DV@7m|Mn=oIgCit|NaxJuKR1fQpRDy$H7;IQ^Ww*FYCBHgTa& zuq6*oN`|3R%xx_hY)8EOz)?}VdV5c9M^#eE!g`&!f7R2Gd>4%?`;y}hQ)D-9%#D7! zSCy_Cx0Vod846Yi;KPtnIDl;@S+jDBb zqosW}kf9y+UR_n|Mugj)8yc9@^>e+_jW{PW_Yq(s546`f0?jtVaa?DVQ@bCJu_LFsVvU_rz|6X zMBn!Y!gDMxPxP(v-?dy6H^$$Mr^E zUObEV=(XhgPjA8O;PY)1MuEr_D2-BV`Y|==HCwu8N%(4m&>xrvixz@F14IG;`AQ(W zl{r(>icXQ@cLi9+kj(>_227%Ekoe>lBRWH9JaSg^=Q^0ne!BHsks;<}!2=old5Cf| znrcQ-9f-t*W=d$YLcyZTNmC0%rUMis4PL`=88Y(I_GMtFcf+`v>O}w_4Ka!bW$<0C z0KejygBgFNFQ+H#r*3ocN@R{hAdT-ZvdjFYCAyLMEY;YJ2$>_GO#ie7Kr zI13vlF;L_^VI()x{s0CjbOXs+AQ|8CrA3QFa z=ZNteHj4#RvUl2!Q=- z1{fs?;Jt%bRF69ueNHge9!UeC=WP>DS`;J-^_HDLBQ%pJ1;h{N69q6tccD1ivIoTh zSoXqUYg@WPLyQl#&Hk_TfX?ep2Va}%M*;1nqr~`Wq;dD!;F3p^UKg?uz&=?5)Uxx; z1x@?`e$ymoQ?S*tgh67vWyO^KHjoeM0Bri`0(%6NT^(S! z4q{LD$zjq6X?&dUdx1s36?`}D-wbYpQl4hiHXTakLv94rL?X4q#T#-c>)rf5tpaMv zf&bALc;b$3k$~QvwHE&8e&BO-11Kg6z`; zG|A%Sko}_pC?7gu1?dVQDdBu+-`_WWLNTP~&kpzQ0fiM5p767YyLp|@!?O<>=L`Sg3EHmN4HfSN&>?<4#g zvnV8Xt!7*l&I}~9Mod7%O6zXlJMXy?y8%+ue)_NPQ~rwyW&CD}LMyydqEECUnpy|3YTA2a z7$EqDE}pwGow@JD*gaogSP@WY5duR`&zUv|A}T0-@HB)__s!(!P!5pTol>}DVj z%Tv=IFSZ})ROb;Zl#kTc+KMR*)%-Y5X>E2T{>ZIcjYQvEjr$^zDmH7qH-2Yk^z}X$ z=@No#__cAqpJj!;x!m-XwEY!AiXZ7{tybZ~s!rgLQ24j(T%!v!qqZ(ZW?oKva=@01 zu=L5J#MMOhO(=>=<+FwVO%dZp!8)UHk=8bkms^FpOF5eU!hqO?!34txim55S9Z};r7dCckK~?P&+>du)!%=A>OVuSa^Ow-7@%c7uRG$1BU2Q5H zdTNiG2?a+*IEKI9-F>L5b`e%q{%_BxG5hAFMme3){oQ>nK``|)?2+TxRZgyv_*TiS z8Zs8NeKO*SuOVH2S?IkK6*px#_~zTt9!cfj?+j#|Pu_pe!<+Wby?@N5GVEA>T~$bn zMsS|HmL1av##rx+hjCXy0yFq$p9ySr+CGOBZGz z-2I-Bb*oTP{7Es_{r!jhxiOV4to;{gslqSH{19;P!6{Fm!Z>;kI~R=YPkx-(nt0WA z(%1}ColX#NA3$)V8YuSTsCooQ0sbWhLFe#|3~1phcp(d|cHH3ay{6Nfm$Uuc;(FVK z0R&A5C}Y3*tJyhVO-R%fkF2oj#KBuqWauPfcnA`9%co6^hgrhCzuizqDl!||LdOIh z314bk<=d&Gb*+F+YB)mR=5ofd%is_AUn0C>jy7oM;J*&@2Q)K~fY4KvCNA@#NEy=# zu!g(65?ks)FEzlcp{LiDW|yTRacKeCf1<~XL&9|@hb!p1frwW~od$D?EFn|~XbKK1 zEhI2SjkO`~Emz)uNerlff!{|6tY2^~MZl%IfU9!~_l*GDdj)7B|H&DEf}TOiX&xfu(`KtZW& zkDCQZp+A*FhS=jjUH}cT6xf5a5Cwh%8X7i%ZeOpB#Az0Qa*;v`y07qAQSj5ilpJQ= z!HaD9bLsa=kp<)VQTl8AC12R!GyH?UT3b|u`3b5PYoQ?@@XXXLk_XYG4h@PshjzCP z^62Z}itzro(TN=*at#tz;ogPoXZxiFSXu0c5cV5NBB}dPWvct1l08id0B;3WMgec9 zrsLgL;P(S7MluX;P|7c=7a%7dWqgAD6y8%*IQ?nZ(xg!U=Pg{AaI?XPAx2_T86G~e z#_fFm`2RGm<24S>&I{jen;|t)%cW^iO$wtel$8hLewxY&Q<%!sLJ3m`6(=a#z`K;A zEj=U*ad4VQ&#HjsZS~Nm>)rSj|_qSP5Bhg!$UOvU#-Ahaz-gAaF4+`KrcM zWkFA{r1S%u0;Wh>0O*=hC^Zf|-a%C}?3(#fNlIZTvH&PnSP3C&8F){koW^0p;)05} zL6yrnD?UxfJLPR;uMBy~Dq$Ogcp4{1O{W>SWTG!8eKmN`S-t=Da28xIlJGKT>nD%- z0EP!oxEqv28xZXrM?}prC8`-$iDF3r-g;5~)j(7q$^VpD2YXYY?xeEb1xYpmJ*O=< zlCvaOW$#Nd8mpXmlGtT#Ybswft;VDJdFgywB8aMv7ZQg9Sm$Z2Yd-pN39^gUMdMa_ z1?^je-gy-VE8JUUQ`Rg;*Qo~#8$0m7QD7Dn`#XyaFER=RUYVX}Hwf>rKXgl9p8fc5 zFxczoPFC2*WTt|DY|TrqOa8gi8;lnRm}~yX^AtpNOPXO$D;F>}dGB!Ff08`)4gdQX zKP~=Px1gK1Jyuxx(gS9v$b*_^?*p1+w{4y4j?x9!m3-+Q$Dz5p@#&tAx;wel|0-BAi*ci31 zgbCip6x|eil7M;6DuGQ#@#~ZtyrA8z;ie49^QYR zj_t3r(Q%E;FDB6IcB> zBaJN^7W-P5ub9}GF_I|mLh&n#v%gigD_HGcCFqApw(WdRONHM*;1&&qjJdu1*=*l-QiUj44 zPef#PPRvS8axi#`Z=xcCfBl)7&p&TebxYURl)=AvCeq?sNXM!2kekMplF34gpSK6| zVxpKSLJwWrlh~*#)zrFV59m{0+cyZQ?1%|c`Obq#fK7_D`vj!suUK66b?NYsnLDK76lgXp+(#$*FIV7o5d3IZuw{ml$|1H~1Rd z+}rxONBXS){&hl`O~%X=XL&I(ExpIVNo*nEV;^?%iGQ#!=q^nc`TiFfyPXyjHs ze7VLyTNEEpvlysWurh$jlzKLFY@aYk)>bM^$$WRJWI&s>W>RT6l(d@wx9av(v9{|o zJ(frJ6CRh3^INRPwvUS|1u_);eobd`n#??8&%*n3%7M@L+OAKxDQ?Z2)qI4?2tQTz zPny}qrgoJFS1DdMU%kcRQgiW8(@o3!iGc>9NI?o1B zl-VAeLq-IO-h{J60Jx!`YFW1Kyqf~tYJ|gqfinor4hCWa0FD0D&YcG5nuk;Q)6Nk^`3c{pN|$ zJa92gQGi$>Bs4;e6=X@E_7q_f;UH`x()$jeAatK19xBugP0-#$WdK0Pt&i*eeenun z@FO68PUwjR1zGriQycgZ*uEeAfRoJ*?embvbsea-23{%E!RFu*OgISfaDWK{t>e5m zw)Mg$x}0JE0*5+!>tJut4Xvn(f&zYu4e&?MYvSbL0Z38idf@31{9Nit`Tf+ugY&Rm zroOC+LQ%Uaqiu}jb`}2zAOCM^0l06-5)GW6Y6<-xG8FBZmORpd7GF({BE%sJ5Yi6! z0gUnmVp;$a2GCAOil>l`%g*jYww$+bF~BK;_ZQF#)81AulnFR^*x-g70Z0@SPYB(> zOcf~rPEbVQKuFP&63_1eBIRlRP4F0$!{8geZDRh8_V!rj@G4`rhc5@`a)JqX2yQV7 z%ZL$aaCc$>Ar073&AXLwFhJfruE;b(iiistA{zHArC3{-+x**o-^8A&>5ixpu*@iM zf-B5#RA$Fb+;cf+KddlkOrO|H6RAg)m6x#{Uoe`#$$-00zz)|Z5e;>tdwFC z{`TvwbyPDVBmTgthEqA1~FFX>4l*~OC>Kn?{z?th3^XZ3SFSgP0Y^L zf=NHh2pq6ag5?yFUH1A#lRc{MI>UVSg0bQf+UB3DhmTM7!P_x0Q~YZ z{D8uuA{JP5KnZih4mMk6y(kd*>sQabrDihCNzk-{1Omk3eZ>ZEz?;o@;~`kC&cmSR zEAp8NJ5V9uFc^gPk55>v4t1;~_t)*h-i~yEYEF$c{w2Jd@WSAslO;s8Is^g(cO*}j z3r>KU{WiaIpiNy7N8}9mgs9`{Yx%Xx0jEbN=9 zQbad~bk>C_6GRn<{fw0#?zDA2!BR=q-lDPAIILPr#1LI|%Gpcb9X*>}AY9u~{Fr|u zXTE!}grh>kA=;Fr+m6~`Y45DhC5>DIPJ4Etx=*GVWpUDmI3EY{ygU=WP~viJay2$* z75WqX?HlO2JO8dbiXe3VyKssFKW5A~+Fs}mUKVk3!#zJdO;-zc{O&aZRqfWBrZT%0 z7q8k8JTsEd-TQ`dYfhQz_AqQWRbi9u+`35S5a+aVc`;R)NMh$ge9u45I_4v547nkZ z%zcyeeWt4dKdB9|t2%SNtf(w};yzM!;}UE@486t-=^&Ph=oE>iH7^aO)oh#EcvX>1 zHSZz4p+m4;WfVi)eK(#5>z}5Nh~!vZQZo82$y-C_9%Q!H(`UR8OIxC4UZQxuE_-%7 zuRxh7`XLDqpEetnYP|Yi_3M>^FE?cf)$x}6IUIAIcr_^gS!mdnpa_>GQc%X@cbU!1 zV{xELBcpy}#%w&rtp^~^hw61-B}rbHFgxJ9Gwisj!jiC}l)#?LU$A=3i|RNbgu?$? z*sa%7*`+KCs(5`a)DiW-krz_Zpc?m>Da#i;B$#_k8vOqIvpXyn)pxg#1r_c)5f>5a zGA(-?XN0{`ZMmuW$)tqYro2y*Uz_i~oAQOfV(}~|x1MB8VpKjBj3?-dPt_2K^YAIJ z4{d0<=O|kHxnIjI%acy*SEV)nT~lpv`aN1>P(@m{)g(V>PR*Ju&Er%nVeSEA5#3D2 zkn*FD)Ri8b04SB5%e&k(JqGS9TuZ3KTe>Z#czOb^O>#v#0R?v;dMC*QN?uKx^9v zHo-9gM7(Bb6+Jp;Hw ziHV7AFveY*Ah)hv`CtMH8bwB<6lj2r`BD>)mmwi!bsaFEpqT*UaRW>$5II1b6T*D` zPLC$-X4)gsgx=-PRf!=lNJ{((rZ9=qe_H^x-h!cU0*Y~mMhb#1fIvhVq^FSpKM%ch zxZhfUQz!#SQ>~jjG_b!$BpYxpc!X@gW_lhmlK?LXq6M$UEk+o<*G`3yAQ>W6WdF~f5_z%p zUhDHuEv|QS}1poW+P~Xt73$%DzuC9XU{)Gn!Vut^9i~QqSo%ommBbZ6&)Fp#kx2VX% z8R=g^rSLDHNiN75HxFyQ@rVx!5ftr-*eih6hui<)H-BI^d?H+srZ!xtqnS@WdUQL* z1_lC!LZ>Ti24U)khf>5K+U;`kO_N9RMj$K~_ zC{t#3d=R<}4Kct<2v~pg2>dUf9r&PbZf+E^Abk7;(}%9UzJw>gx3k_13E(7dAy7scF9w}!z4A4^nm z)cWhd(cvKj>)^uH2F$?*;87^lc?Ib;54o!<(~gbAYF8&v&JomeAihCmCM1kx7PqWQ z_F4-;@EYYaMM_x+HxcvM=7ef8o#V9wYmZP=?mvSwl~u_$=y8Ms5o!h@m>RIgN6689 z1%i#afHI7>q)rd|fNy#D<%eB52Y|z4B9CBh^9hDhpMVexRNv{AGfa6fvx17emxZ>9 zRuR4vfVj;9#3u@m`3t`s6(OECQ1(&0GCbix7KgzSc*EL+HDI764e@`Vdqooy^p3%! zQ{}mGs}ZJCf1&7#0=G*6{de(+Ete*$wagAJD|4J=^;KYk-xu$#d93gsf*v{gNB__g%EK%Pm71 z&nIH8FDpL^Y6-Hfzj9din+xw`;+b>J6M>x59pTKTh{QkxEZaimN@4j6i=)q^w;{@e zJjeZK7k1?KgU~;hXSCf`uXRmj{E_9+5%6U>45zszvNK+p|KzdfYSXuu>Mc6~5{Z%> zG|J`=?^I9z(klMyl6U9iJu@jJg3$dC*3`od z^;;D5GTZsW8>v?xJ`oARF4Ya|7b15wRLpB$6spjz!LFl<*M4dfV)saST~JVCMU-kb z?@z37eGP zcyQ#Knx%RaCLWm%v!-oJhBm%aiPr4dwG^d){xYe@`|~vn?OWmP5)JBK2^~>)YkLgb z`Zi`VU(aN(w4N3}9sMrsCb``h*L;=bu`Y9b{g+U?=%@+Doa9mYCk)lUe-2)GedT;I z);sKOOvFQz*iD9z7aR{8@u1*TF zEbG265TO0-T*Biq2TFfQjYB!cN5dP`-1saNDwjQ8M?yJKsuO==Pt7(_!IxAXGRCg_ zr)Lz8L>k9ffw;9gfBycaVb#Z*#4W2A;~4d%dvsQbldUu=Pua&NE~Lp344faNQz5Bq z4@jP48g9l`xjy2JA5W3DyYf7o_X|0Gnj*DRQgqXIjRHz0pWd@wvH4_a*1H#!X>nKT z2orqIM!z~&DuI7H&GI(8ZGJCr@Wg}oAI`qNl+zz4kr%Ovr-gD%$l_niW5$1CkmFC# zWt;y8@Q@JkDVn~S{(1nSFCctM)y9RWoXg3brThJd+lM=u! z?kTin!sp9o_d6P+;)c`FhWi9gO1M-Y&>d#T$iQl9`UmfE%%z)f5R#bNCC38u2;kub zT3AQBYqKD=cK=5T2wkMCf@=xE?p#Ep$JaX6_kiyl!rOz=9ned=^BXvi=#V)++S3EV zHYZIkq=2-Alj|89s|(Os1DUe{u2eWnUO8NV+Yf?{EANf)A^eU2od=GTk@Z>EG6-@4 z7&9b$SONFk_|flVxdxS!AJFzdt{e09>*paiu|S$ubj6B z2`qq3WEyZ_gA(W4$feL9c(;LkPQElU_3*Tg`9aDs1dGH7x{5L=4haiuyaMDhA}>Lw z1y?nLrM&>kW0jK&o#m${4nsLsPb@@zXRL+h+NK3S<-oDw&7d(P3Ns)&D3E|T1_9eZ zg+})^R0O@Ik}gw0Y;}8}nHYIg1{EIg*b^{1ckV1J50mc7L6qGzC{x*!RblAA{}~d= zM^-#TzMyEVp{5rvUT85gPu6?!qnK6LWuejy$lI=SU6iP~)NiuUXw#ZEg|HY9B|Lb* zXHGrTZ9SM)0QDYV&K|`uh{?j|*Kx)iCOJV@zP+>RB!gn5xLEhrErdFS3~UUHmHdx( z;q|Bj@Cul(U6ty>7r+u`)ir^BS13?|KHg@FkC>gVcAk`kf&k@^G$_M+($VODgvu2x z2D9|l9VjVcZui4b=_x#R!1RWX$u6sMg?aq4Q3|y4y$krGN5IAi*Y2LFDqR$%hbzg- ziVdzEYoNCxemvAD$d(O?O)Z!!C15JxZUR#1cG{toQOa+9dvKs2JuLd2FjTIRBqMK) zYG#%+br3%Gd;Su%=}TyscA+ycKm`esL_(?= zjCbIh15fd7_%-`n@etP>KChT<6DU*j^IL= z{qq>|FlCYkJ(K*(o2&I zlqKxxoj5Njc@A546EF7S%VMg7Nm2^YUttk<+rQ^8o$T#b1mzN`Ds#BKi*0t`i#C;+ zT-U>#JZBj)pGkj9VK~F#Q4UJW_8J%tcbFyFydA1H^y2xf6&L-@jbx#yh?@q|8ob*#A*1xF0|Mjeo7G zzJ#8z5Pw?D>8EaK+^gJ+MJ(KdEWzH*@tl%KM*eLDo;O-)MNO;f+#@!BEM%_BkH2=az$ZCYVh=-~u)eaW zq@FkZ?IP~6e0Z}gd8L-@B97NP)qPv3g!^hyMAXLDgi3|NqHV}bO2rL7eiyWQbYJX7 zm6CZouW{?H2t7_u%3X})@I572St-ZqcLdRMr?HZcuBMTKbdT=Iv@)N~S+8(d*ce^+16shaS}QO~`gun*$?vXb-SmA9W>(1DBGR<|dK-XFaKpF@ zO;n*&VwgDFr$02d38n~f5&eSVBs(kVOWV$tou65Z{l!8Ugf&ghK1Sm8gobpx%Pq)PSN?iI6#t&x>L_J(bx!9;BMjbtzn?R4f|La5YyL%&p zzkVGT&3z>%A^G2KmZz*?=K04^4Z?>HAA(D4C62yLuy&dC23dvSm!-aLDtzD~1IjlN zQd40t4QG=r9FWjO|2JF(v^DyhwqMB}YmTrz8M6*8x=!d)|!1qJ4ksiH>ae)*K zG0KNTEFy?bYXK2FTG@hZ8|r~_xT4VAYwA6Xf0@v?5`>zFxB;5MNJQL+Gtmuo(1U}m zaKw++I9){BT9^%7W{mFpQvey`bZX>^-_0!Ao}2~fqN!=tPytLTus~J?wg!)`U|yM8 z$VmdF1vtnwhM)zfp`j6#M`{lL#+?%9vtt*>qRy)%fSE8CW*FuIT`)LL6LaS{Z{k4- z4q;oc8l(7Yc$E`@r~V@MN1H*hzz};8VkFS|AH|2l8lVGO$N%Wp;3oJx-FAPnF%uX~ z$gBXgMmvt9Z6mn^jJ1(1yh@bm6n){^e;MZBOmqPupq_yNe)F^I6Qb?XCsoqlW7+xn z$)SRAa1fien-0c31G(azvl}q>sNf}0N!SO4XCLSm(}ZlI^zxH_)d%VQPWW|ktkbH3 z4sO0)F^N%(zxNRYmPlrk8yKLC3n&X~wi+l}>mL67vE%a*0G1RXkhw-$M6>BF;Rg?W zAobkX|J@yI`;%>xm#{5Eio55oY|UgBv3M*7UX$x%-oz~r9Z3rga>t%JXPwTQ%l~@O zUm}tdJ6*c_J?@BYt1PBHj+ur2Z=ZX-Z_b%_bV6sYd~Mlm3PNv zAF`s^e5Ia*2jkg(X%!}jr{dsz5ZR#S)GZia8R6XBf2@w98v8X-K>CfgivgE1U*zH; z%}sI+3U-_#S=Mhki3eKL?99{>oniHNK3!8~sT(=e$AokR@n~_lWisv2&{io?_={;w zG7T#K4*ed_N)d^DN-Vwk1ADS@L2vtJ{A1m58slCo=Gh6HW4u8T+$<_?xD zsaT%>E{3g%G($NtBnUImbmbXC5#JHO}s7u0FFR9(9a!5Z>5F?7A7j2Wc9qUP*8@;OZwSb))6sQf3r^{LuU%v5lox{wqG$pD9nG7Ehz6`xd$mch(qU4kf{jh ze^7W_gc0TCU+<~Ai_SW$)i__injTZ?*iLjEfhk%! zqq1cU@>3w_F>}S!WFh75gK21ntU=(UudfeFSTV4)iK0J1S^|E!4J^E1^rV!#jMU=* zKWC`C4$LCNr$tzLsz{W~0s#z?J3jvk1y7qKZ-=%pqpcx~CG9|342wega+^LF^Z^>; zWktn6{K*lo;J>Y7dAfiaSmdI;A588sE`d*fJYhP1=6{k%?W8L&PnG)|@ z?ceNTG5W6wog7Bhs%Z^sHy1U|W%|2PIon8{kappw?`D766Ciq9&0MGS9(RW6%rdjw z!=BO?-9~@dYx|$ACNBowlpCQY+Rn0Q55vFHxJq%Q+v{pX`4vXohX$cGL<#Flc&WA8 zZZ>lf))`LO7;IJ#ZGWLhsYe^xxh#VezSlI?moc@U`lmi{TYQqn@M2VCoZm0{32Sby z1;6#i9j3z1mnInMB|T4X(W5SQ9gp;H%!oFzO|)X-HcMMC?`bu~)Elp9`Ztyyu$ker zNhA_d<|CpE9wx(m(<953q7!=AP&(zft^J0E9GUd%mTHXwErCYczUGxM`MqV)iUnQ5 zv)aJnAFFV@9qn^G%LwtS`6M&)O}9zzxv@mSM_2~5exn9<;O%(w>z$CC=kkg zXlYj`OO)xkMP?oOE*R@*Y9ajJ3H@6tR%}V->lu0E7cfHNEl(w212|%BS+DtTC+7r{ zfFYRJ##>`#&VO-~9{4(t37;e>Nc~Tksk|zM?^TW01J!#*${|?oW_?pdI`pqKE(aSM|dK2)2EydZG&j2+r}EO>0m_FPfT0R z{LTEOS6XM6l1i(ALr=uPyyM^Xwh1|WF=iGj>tvc1JTB(5-aE?AyDeU^snujBjh9~P zp^C%_PT(20d&Bpnw5ds#kmZGXZ$pCw1g5#lEQ8mv4g!bm5N42&}xjDp&RISVICl3~dZku3>m+T=Kkx(xP9nWPcU;`nimB zT~Fld?GaB3UC66OJ`OE`wnriKt19PbudykKyDSd zDWN3fYxpaRqT->RI(f@r2&Z#M&+pu(CZ-`_NPS&EfpZ`Nko%EpVkJvBxU?O&>#u7) zGJu>HXlBFB!-EzQVASL_s(g&nJfU#A3P5#0Ph@hx){26;jYB%r`C#J&{eWu+7Xc1|g;e-YDEq_Uwh5A9qN)>)Xjol<02$P%Afk)| z7c4Fd9)IILD>4W`7B1}1QCJvSV=2u$2*7!cg@Re279@tT8pOW|IZsF+jW7%p#K4X8 z2(+u=FV0*Jg{uq5Sqdm`3xot{e+G_AJ7lPVbpag9SK;47qC?*7_~X$ zId+d2$o1a1ebqxMf!VzsE7tnxnlkevEc!2xnd`=z@0PI$;k^1*eRbQPs#h zfqanZ0-@ZfMYV7(k3MVa(cfWCy0wt3Qxyu6dl$D8==NFjEY7g`zYH`__+dMmHWFF| zxy1W%tyH%dqHRe&eyilsYZ$z|x7BrviWIM3eqB1~+R+OKsg!FbDFh)y&vEHbT2pFT z8E6{btws^XAFsU~Ghi0c*&u4#CZcc(0m;JLi{COtYVGeA89u(BB$Fh-)H524r0g%$HU-$TDYH{3^)V zwOEJl4<37j>|dlGx3d;d_vOP>8zLINin(Yvd_R(phg;i^B7K%+*Q0~OoB;FVXBEv)1L6d3 zi|Xg*r0n&WvyWM(({m~A%qa30C|{@BzM0s%c#Tn4jalHv>a{oXoOmhq5hr|SDeFJ# z8n@msQ6TDh&UntYz*F~fiGb6i2j%nsrv;$mY4{nQRv9Njl_?d&MqF`}N!f)>#LhWh zX>}7{8Uh0PdRhq$P1d)3}a@>`9U#1ipd>IIE4%_y! z?w{E?npCGO+L_JBwV>fvzFAfn6P4G=jsF8 zNF)kB2{q#LQQa7dvuY~K3x1>9a!pkxHf<_ls<64rP~l3F%#-c#`Tm={p|?0^OLYhK ztSQ8HCF=BPZ@(k0d^}(`HW(7QgYC!ruQm zq-K98-2$6ESq~(m|0*lH9`3ZN4*kZ%fEvy zl-LO?5jI%)!>J1Sn^w8MV22DFd@MvYYii1dIX0wx_Zd1Oub*otdzZgcC6#kUayWl~&NeKz7rVTU|MC z57A&tzj}blNM`fH@3Xj5upNh@1-2k4UIS?(rd>vf%KIRo8;);C^Mv2K)_mra1Lh^L zOvf>YyeV%Or}?a0`7iz!a1p2x0fE5is4`@jV8QLww)SvnavTC~;NnB~SD*hMcND`KzamDDgW;jv1?fJ8$;cQQYK@e7 z9v&di4ETt^(m*E_AOX90*h@iq!HE0!{_8;aNZ=8B;r5veS}+?z!D7zOae{Ca;or{{ zpdm)BB~sBsEummx!2(^8i*n3cFAc~$2cg0Y7*$yq$c7-46!e2O#;?oNnJ9fciGti&cEJe9cSKqUt~ zD^hiWW^p|_vvdUoRe|be0PqNaEfJJ|Qc_qL0ir(G%+bMCz#>40i!BB zvOugtR!~Bg#g!E+;91Dc%@D|qc)~qN_^Ty3ph*8S%lOSpvF?}BUN2Ehk2Tt!cFej$ z6`KAU)UHpIu?a$D0A>P%uLTCAkpV9ZT6n>>2=;IX{PWn!YkJ?qt-Ns)}qbfHY z;Wt8GRp^`1@<~n(H%B`HzU81k_zaU^cqd@53DR8n?nB)OA;PGW1f|3r{KMV}#%qo= zkU=lcI=#&lLzxOIC>tmWB@P!~fw%@5bFj@?19e00P{`Z)f8q4u`n0xE_rN>j1}m$w zhD~C`qTksWn4BbUhFP%@1dhPm7altJnV>`_=Sy7?&CcV6+73yp(MT1*Ah4ZGz5q5r8Np;$QC`+J^6x>^-Kzd3~%8nnA6pV+2vYRQ?Oc{_CDOODY-HXg3b zj~7oR=6E+=&Z}8p;GJ1t=A3mZy3lpbSb6a*hwX5y5Pswp2e;me(6qV*Ul|-W$7z#X z$~()IZ{&WcCzLHxVDY-dRvc4Fcs?u_+0cBV{-yo>>Zpqf)pLIw|E5S6xk;wP_U}ym zb~cPOG`|?zn+)0BnH6l=xxP&b<9eIODuux}enm6)Gg!!6ZmPIC)U|Mm`}W@~n~O45 zV?qiYIwbGp6|a<=T~>FYeWU4e_e#>=%zW0m)o8D(PL8LKcjB)cEbban{Cko5g5+aY z7?sZ}<{I{A;%bU|)^@9p$rKaA1{~uZL--!@B#v7xq{`sf?sx$`f8PiBufoE5@q0)yFueKe?3~ z7Z@5lv5-L6e7k@W6pXs7q+8p3vxfg4Q`Z5GW!wIdy+V@hAv=|smA!@R)sXB;$cXHn zJu5SNln8~8T_G|n$zIu8_V}My-}fEI-*LRh`}X#}J@<3n_kCUG`8$7uqq9*;sR|=D z6nN$LHlt&V-deY7oI8(m=T*7~9#!&m6`#7OTbs!rdM=$Sq$C=<`{US`=*>ep>$htA znPDPEnKW(FV-%T>Zv}LQa>lbw(C1N@j>~l=@TX`9MxA;vIPUqpZP%to=h~7w$ErTr zs$A5TmW6HX74EoS$(+JtQ8zgfd|k1oeSk)4AEMEvds*FI<4J?s!jXTyiQ5muCJ}6} zcf#7NfAiJKc>)hQ%`KRNJI$R5n>hN$lS?L3|Hu;!k?m*=I5rgAeQ%xSSC#AU)WcRi z(3tUdBDGj&_o7VMVb}-hA$0J<^C@f^Xj^ zgRSCnV4^1a&H_sWkg?*H;>bjIA zV5+%wo^!!Axra^ES|fHip-G+4vOy7K4y(ds?HA*Ir_15u#@}G8=Lj6+josvUMB_9+Ed)@i` zE5YO>wR@ot6XpbMXNd2WMCSe?)0P)3346W#_U;{{OIABtFM~YahfklH%2eCM_H#5aV00Vj|HjO`$g6poSL)~PJ$ zoPl$RQ)zt+VF=xDZ;4QTiQlux90L=sM5 zkcz>mJ0-sif~vr9DT<<*E_A|w-~q)~pu_=3o*z8;kjn!6-ZwJQ;YOLGei?7e;7&6V zYNf2aykMZfBfULxeL{MOjqiHbfh7suE@D*h-TV+cQv29>u7?#+&eOk|ssJoVh4!`w z&L7lxUiMihL4piut~%i?M_J2iXCJ{R(M7dNT#zcIL7Jgvc0}JI$#AnCqn-7L6 zqYyvE)buO-PcZ45g2T0C=8+C$i))IGe1c*S*?s_J;IQppyi6$Y4v|C_ToP0;6o}yf zCIN@1(K!Pv1Z5N)B5>TfxVi#7^4x0T?-Qi?FXscIB`9U0ziY#(ES=HK+6R3!Hw2Gn zg)&b91I`j`Dv^Ho)TT15i`x0T_rl0_h9mb%RA!p)jrUbcekM8yiU^sG(RrFF&# zH$X56oMN!NNCOEDRY(pDa52=6@FnBhrwd-M!2de~d6~p`!t9b1e{!?L`J;aR-9Rgt5s!p2f3*@D1A`Tb6t9m z9ffX#RtCysC-8j`@Z5d?QU+A)e*BmQViqtn(S%1IcoAOn*UaPxTmdeDl=>M+AhHH8 zuuoG?P&st2n1K@-gpm~fM|{a1r@Hq0{@aOlcfj585fKy}+)!g4w)!51S`^m9svyHF zMCkx_6v!PA?G?NiBuCGh5V46>iaxBF=Y#|LALYccd4eB?$H2<~WdpdOae4m%=`vbt z&@g0Y2SB-#&~7OPh&jW+QaEIO6W`TH$G5~C@Tequ}Mt!_? z?OGXM1PGN{ufcnPzH~u6azh3-=Ut!8d~k z8JK0~Q0;SkxQBw?kpc^_9uO!Yg$=UQLmLi=)~11Clm`fVy5wU)c=Nl&rbrQ*1>Sb5 zyw9>>$GcGS0dkqI&v|eLVd$e!ggc^zx~mn#qY9$))(q2Zr~g6gl!vb4*F$GW{pYYI zs)X|LdiWi%%x(H-)0bO}-m>Cyvr38;8E{kIX24?alur0_F6U`@LdWnqnhzV)@&!+{ ziI*_NrkrQ0`$X6BEPL*2hLwa)eY$;})q|riiu34(U=U61`j4T(M#JCaG#B=$45_o< z7PP%jQuU~OALE-(vTFIQ1jkWrceZBj=l$0C&I0ljm*X3Z=k4##<6KK>RhvwDihblQ z8S(pRvax)czASY|S9F(Tcc{(B1=R=bJ~w>Sby;U`p3X|tqddE%PBGWMli{M)rV z(H$?H%&%SJ@z}n(_XqTd2I-CF8{Y8SocvEh+DVV9Jp`K!a;Z$Gm3c}}<*U~#?h3NR z(@N|Z;$$!CZXHX94O&x;#VcRX?@{My?x>q9sG+p;pMKV(Zk2eGXsR$1=&;HW;)#~K zAC~*^Mu{HcINwy2OO5Q|JsbXzkX)6XOcp0cp=Wal_sAiS6%+R>u}SpSdH(Jw-x8<{ z0vb8_60Z{FXi^L%eu+H2DdG6VI&)H4tV--*vGGVij}mi0!O+ITZ8F(!FJ$q@cz9P9 zV(ZRXsjdjlEj&$dzo>9R8=Us8&OM)d`Wykje9;9Zqfza?XA6}w4%g3JVniR;gFVzCZ6G7sUxA?3y3yW*F1?xZn6fq zo%1d<$E#3PPQux7e>Ukru)JN6wb?HA>eLsfgX?{=4&|y$f`rty4>k&PV(;-&Z1h-| zZK&pMFKjuqQ{0?xmQ+x>uwkK$SI+XC4R7#5-0-V}hy0Bf{Oycpg5PeOKVW|^PsfVe zg!j-(PuTM~C!jLp6UA%$%o5Tzr?5nfa~teVhWL~wK4IStnJlWqOtUALi!*6>qSI3N zmlmi_uAX}t#%qX^c+;*?RK48Wf}BAW`vH#S_#riqV~jz2Lto6`eb$2+e0`v zYg-Ar@6pN1|2lig=|N0l1Hbfq%;Tz|Z-8+l1HD%8*82vn5Q4FxIYJbTAI!+O2P6&5 z@%E|Tyhf^>xnR19+BTs1MFV!x-Cp7i;C3~j-*nKD0l)$WfFQ^(fuN3@gG_2%cxp<3 zdy4jR7+63O#91+Ae#-;uC}otM05S_W6hOcM+p#iX6C!iKjNzng=Z}t|fzjm2?6Bw9 zA$0C4+|;geM;T=T-{ytV;v%7IDIb6wX^@2)}kgdJ(mRj!fd zCKM_TOcTUpMMGBP`3Vz9K)PYpsYuuaa@0M@R$r)`hcX0FUV)^9RC=6RX{s>7ZMqYS zEJ%R7g1!Z4Og*c#4;$;xp&SaJQF<-kH?sisCc+L-If0VNC09P=m`%5>L*NmB@5nI} z?f2kz(u|y*0C$2|PY@MA@|WtR7s&M#&42-p0O937!;HD?2&k99KO`QaWuk!e1V(G; zP=E!DKlb`w)T1%WFJSCs+9pA;VleFK0Z=49{J5?R8WOoIZA?PD<6?0YvIfyHGz`hN zL;3;0z>)!`-}vU>lK{awncT#H!Grh`s9j&dZcJH4rL|CR6^^{x+W{KWak-%r)Spn$=7Dqxye1OB5VHejRRG&Tk~>I_x@V`xKq>-S zX;3XfN%8Hp&)O-C{#-a837FmeK#y94oYNe<7jGbW3al+)YT$MawPG^zOHpJK$aj%% zh2jnl#&^!MTEIshSJr^>JCIp{h67g#4&;co>M!f}A6gNrkJ0o8>7Nl#8Om|sRIiY- zBnk4-lA^k{4ALvX*Ygt?ym}xw1L_SRZ=sPp63_We z^|Wxl}l%)}qDz2J1!^{0>iPM3b4c%nYA%Q(@!2yZdS$fbdGkU6+WF)>J9jB8~MCr%5Ho&sVO7~3L0nvtj2 zn*F~Td}SOyuH)OJIpG711DwMVf;jL>n1-(d6u_vzfqOM6(Z{rh+y%9y_bjWF`w&SF z${}!kpPdM0`2%iicc6SiDlCMexkzXN3JpO}vIDLt1Q2BLV?gScmCl%j9pkW+9;SiuYb04RgNTY=XK%@F_ViR`tYmob^<+||TS3ft`P(~LOo zXu^%Pe?Pl(0*5c+If%4&bdIZ^9tmPFwH31UdEDt?xVeo+S>qYGTh~4Zo&HWGX&*Pp z^TlxIhrq8FyxiDBrCP&Rl>g>jR;i`ey&0Iu!9EgWNtuG#7m)F{tDG<;^#5^j-Z+_a zPnKE7ZI}MuO~6Ec49EWZ%Y)C>s6mr=$gAZkva3zqP#0*zjZq$=qYC0&`|(Sug@C3E zUwh&7iS8fS5$oT0*%zn}6TTlymZdxM`hTsLt0*PQTPUOJNL%D#^*7D`J!RQfKoHp4 zuC4}=@OsLudoIiuNFb~ngu+%#;3zj!_bWZLjaoFvMv^8{Vrp-jpX*OVK(JpWNf;ZJbD4Tby<#o6+) zzhlpfiHDnB=1<{!l5r7}bEIb#hgYo8B%^l9{;9qJw*T*f0(ra~O*Vc1m!k9g-uOir zZYi0kzS*)rhYW348<+f6S_RoNI?VI&b1ih2xwA%JZ0_EZ=jI4FBv0>Lln@llObt#y zes$o_XB^OXp?SC9S2TMTaVNRjBiHAjrc8$9vc%@!>VGeEOm|SX8jFg!TBX?c-b_$1jW4${_v)v9PaOL^e9k#`x+=-#ZC9*E(fD?l?u9uI zwp01+ZiW;mVZ5ewlgegRW+HPQ5(?WD_0AgF!g7w$r_CxJx!V_(rIQh2PAz~X&`*DI zRr3KRo+Xjk&1F-m!D8{qO4hVz-d(R(LIv9**=vNK`i6(Rx6ELcT^-wWkx5q?bNJoG zf9m^F`Q7jJsuj{Wdit@iy}!AIwfRThDM)`kK&YYL>#0#*OZ?QI;i5q;epvkg+tnN6 zZ@&|!W$U^R_{6F)m8UMv-g5Up&Cxf|bxF0jyZ6XA-xv3sA71bX=aYA8mSHSoozkt3 zC15gi-Fra~fP^<@0zUxe00qpBHUKaiIWcu)9X)rv1~KjX0_FZP2+sS$E?6#iDiV`cLEhQggfwy)9AXA3F-z(H&RluI-TH$ehUfRNN} z!X*h32P|zte4GSF49t^o9t3UzPkIlTuR@tMEK&h| zJ32D~sY2sGfEWMyceI0v1yUcI9sC{u9%0E#JREEQqS}#_9Josu!$b~@_2Vz7fNvI9 z*j`4HDUf6xF}^~BI!J0k9yTy*92kGuF@y9yyCAil&OgDHq`CvR9V{qve&e_4O7n45 zLDUxToA`<3i;XSw-Md&WT zxb|7?vBc(Gs8Ei6UPM%rZD&Lox_2SpfRWgCMf}O!8u;gdvpEcE??QP48flpc=Zu52#iuR^ zlmsXi;g=vYhguoxOjOw;QWi2+xeJfc!y!?K3_wwFy7z&fWlU=^cDF8DJP? zbkhkB8Uzpl)`~O)QJ9kCZiQA)cOzk ze&jBJG)7PPh%jCOhqZW*hDa@Q-)gqBaXMEX^S+i%p`uHuEZLSM^wg;R6xXVTt-}-N z{VIarAD7!Fc#Vg?{t4NuYmZhJrpE-zVz)mseVrK`*ikpc2z$Pu6$rU5WR1h)`7-F- zLGL~Ntt$zg`vFn|EnAP{O=W9KXPX=eGjINukmR?ovKr9QAtt^vNM1^)!0?H=zfaS4N&6IWmY(dsVUD=sogc-{-do6x#h(*~w0D&tACLzIuhf zmXABm?OSSgMNFdK?fthGy}db_G1sQY*PR@H)|5zd8Sb?$rVINera!;L-w|q9a@Qqw z`f*ZQfR&2cZbZpNe`aix2TZxEzqRUe<|b^Fj8<##aO?&qdqQHB?wY;5hEZSM%YDu( z>?IoQP%b~Y&bTa^x2~zPE?SsGq-I|0pQpy%RIDkFf3KeA{FAbbtq71)t$fyybxa`+ zNoW4n`KyKt-_p8Fg8gg;iQe(6r#CKKBbLm*d{QIVf1>t_ytj8?^K&^LaC_r)LF_u( zy*5pK{J_D108DlIMqZJ+?Pmi#+jzxVvVsQCRjeUnv77!jT(_PkbbV=;gAEFZ2*4hz z*#R8aQcz7WIT!G1aL<|A-23&>LkM))`zf5E6V$=XAXHLNBhOQ(7}j@!dZntg=^O$Zg62gD7anLB9^YzKR8I#2sY^8 z0mTFBGzDT|?ESZ%3L}eK{nwE}34HC%jgHHwDqoUD2JrZhi{T#u5q0rx!%F_WqT=FZ z8^$Y|39zZxXr)3~1>=zMe42pc8e-BE1{i~61J?o7K#LoW-4KJFtnpcBIpZ1jdiO<_f_Hz&p9w>a?~>vw`VH!{M^0Op1=k6O6`z`fiD#r zF>Agq83IrUe7S<6qK+TIcf)7D$nGKIbKt{Bx-B7BxQCwM{UavN|G9=x69YCiRNSO{ z(}9R~>v%%@>bWL2T|*-yPjZUW=qDzI@Q4Z1r?p4jAgJn?f2D0a{t5r})sEmMyl+ev z{iTo1HNG&Mi;%PP2+94mH_^kD^F5Ct86zz4{$@)q7X7E*%8vPrtZ8~Ol?4X^oh|40 zvyH!IuBJHgKDTeI?!zVa&SK?tJ3bLu%4VtXYMBdp@8^}_RbMZ^t&&5~<*m|}S{9j* zz^wkFlGPN;zrJ2F-s`i3ukh4rCznD%v{Z)I-8lg%8gHgoW|qB9Dy$j8rGMM9ePtqD zN`#bK1qnLqeYo!|7WtVoMd!36Xx=~%VEj6 zx+%xo;#uTloXP%egFfod4_GneRoDHSZ*bmvY!%`*%6J~^J@OO{e#o(G>Z%X;iT91A zPrlZQ-pYz*BJL<3W)Wi~68K)NrE|N$*absD>n2+2tU2ot9T}d@s3Tnfw!k0k6r=wz zsz~<>>rmrq_J>HdiZj!)H*AWk9IqFitow+`GV*ma3u{6yk?vQ@*7ja0B^`Xa!*ExJ z!iLiyZOWL%kFLZs1BrHE{ZE@{Y}_CG@fB-kH3Ft|*KkaXa7=XKT`O*H$E`f$*7O6$ z`>D~wV=-q{N+t=pW_O}}PIF9YJGVPgM+!4{xu(;f)5?gU7T~83Ke(FFU8^J1^0(T# zOMCIfJL$73Lu8L%+NGTDD)8SpON{4ot95p>dV}d3{qyeGk+#N(KOx{)yHb~%2s*;Y zb&S!y(Fr*jDyFBySEsg^e$Dsk8@LWhly;{JFcHnB`A4^$*LlMbJ+%S6;qhx``&5vh z3#bXg8xe9M%IkEMRx_0`tx|kGBhA$+w~T3biSMV=9aXMeXO>$72T~5931ayZ*G)6$ zjge`qI!~=l|LMck80Gy~7TXVuE2}vpCf@Hu_-2h8#$+Dox2q0F(a@>%bGFJlzgLc% zAeG3e{<$8~zoU|%Y{yq4`dun!`&Z77obyplsvF{hb}4S3nul8Zer9#KZdtvtr;lJ% z5E%ZzX!kKqfW9s6OU{Sxv1`#4-RaJZDxYMd-+v*0JQ_Dm5-cCM@P~1p(ZGN}sX8ji zxvQ*A^tY3Q;SuA*l5+*64C*Bo&*qJPuZSJ{fMGT=%l%+?g1Vrb7-w9xZZ` z1+ERoF8ax^|d!*~cKuiJTj z^qc>FM$JZtmi3tb{jkvJ^Y_O`kL6Ne|?Ue5lXU|)h=NIo>p6qI| zz=BiVOZqxz!d|A*I#j&7NUWTj2pFR89e_8XS6lyd=DIdC1~ifXt%%pw+G5`6P zOYt{jh@N}6{B_2{r4=WUvS&IHXJ(#?qzDKPCq|DpIipx3oIYA=YV8CUpTk15%23M_ zLdye(*!g4)Dp;uEx$6I{5d7xcoEh#&%YT2S6u>eN6Q}B48zuF>@A&WUYizRr`?a*- zr>S+P%jJs`|Nm#*{By+TzyGb^d%z#Cro_VJC&9&VbO9>l;Qj2%#8WqojlM+dsjk{p z&$>(3cmOu-7ar?BK(Eoq7o=`}s<9?vuBvw{&K@zwt4^JY?fG|MQ|BXv8yD!;I_(qM zB5}LBT=c;Z8lZ5Pr30?&a_b~hX=6@7-D8+RLxxn4{Bx+mO7Hl4*YWPq^r<|}^CLHq zbu~6(3jj$^JVk2X+=S*}XVnH8uaXimFa19Yy{Pg6#1?RhR;${RxGxs(yQ8M^Vxrkq zuB0MbE5M`@=|0_QXP#RD+VUSX{n87Dfkoe68LKQ;JfyB3$IY9aQVi(9ov8LX{LA}v z@;vK-ssAwcsqffE0&>x%qr}*|Uv%!BvBc(okgOwzOUGv<(04^pr?gw;)FDP$9fuQ` zO{0&aadu^Q=f^}_%>Rz)`O$W7pZ+xO+dJ^mA}&gLF#1ngZ`;gPhL9Gvverk=gD-|v zR*H9G5@Ob4j0D5CRH@q^vrF}r{cW2pm5y}jGIm^18T-Zj>NI|(n&;}Y^Fh;6tq)c$ z))ryS_tms|SGipe2a8_PSQ^5->s=Gm7jB<_Q*Ytj%3l!e`X>4;kzkBeWwh6^MQr!f zk;zUprJd}WV$>R6!%(~qOPfy=!@Qjp! zd)jV^RMgk6GTQ(KbQN?m;H3&v1W3HWf@auvI-?!l+ZgX93p5>A8mWB&o@1n%{n4o& zYL;KKqBSKs+*3Q0ym>(W-bB+cAdon(jmsf;3}6in zZfg1$i!s(|(FD1PpfQ1W0Sm?&S07rzD_A;zLc=DvPNLpLLLvRbLGn`gm0g!y%t47i z;pq>mc>%h-A)~K&ALL}#cQe*f8U}Rw(;KaQGXTrNARO!8*`+Pbwjz|htTap{LDakS zXW!%FiLmk4TIwg;b^mMRnM3}oHzcaxq#fovxPIn zgxnt`$(LE!fo5qdLVT}6mZs76U2F~GAYu9`5jJfu9>TMch9nmxH^f`{OljWh`nzQZ z7-j*i-Fx4Bsj6ou=_%_%`PZx5uN!4<$Q5g^2cM066V_{P{?Wpm?vdB)pQ@NsrgZ_o z3&$_$HN}IQl6OIHul*o-g?|zzs8Q6Q$*3qZhbI zS-tm5l*;&y?@yywra|)VT6Tl|-Gz-hCo1~wLv8i6J@emM;n+2o$SA~XHBHk4sPs$w zIF%zl{jy*c5idV8fHQGsm*MOhQKuHkJI|ytdOX1)QPDrIZ1ZqFzrB-xF_MNPXHY(6 zN?26-tE#pIPK7-7XCF?r*7kk5`1l~}eV!A=UoUg0ZaC<-*$b=I>0{Y1)rQ04?&6{i zuCOTa;U8Hn1_;Y;tE1ll99%tKX;J=Pk^zi-z}py&E@v}*Um|d3V`JmA>BlMNR#}lz zWuuOA*v*FBr=rk=N#y=dg!@tw2S=x#R&3!1u@}n2!W{?MfEA3H!9Mx4ad{TnDgjy& z0RxD9U^m$chg+GzWBAhJ4;{)rj7dAAsIF#X!zHud8-v0op$(C|cvi_AJng_GXyb#R zATtpU2h8BnG&RljC&Vz$S3>+EI@+PZ3L9JX?H>9Lu#_*bu^d69^*ul$0E2DxF$#1#5|a zQB`O>hE@`7YA{3e>B)|gj5gl~;o1xw+%6dm@`O-IBezC?jae)kMKu$ku>_uy3ZTjE zTfHi94ttL7JZ^^$_xP{b$uj^C*P(%Q2F@nnxwTqqJa+c>Gl!MIa1CC2?_$}FMR{4OBt&G z&3XMTGMhm%y6%@!UbB*8S#RDfeMtNPC#-essZYnnin#21>G*pz{YT#QgQbt}pJRG! z8&Dp>7Npg?FMjPNE-g_+pxJf#5!%QlDt=7)r|az|*a2L756;(=lFbI+!k4;krS`U$ z^7ASkC%3fRVLT?T|5T-bFqYwE+GWm>m_Mpg@d|d-M_s+ba))}XYz6zhuc?E^+`^Ma z;yeG~q*vv|wJg7>k|iSkdWr1J%c<6-XTrX(2)=9{HRC?PJn_$ncMjNnZSdxy+N^?| zpu>V2dkclPZG~!c#8Ki+Ms|K31@2gPVU_OwyRLV$Bed;2d z%*f1Elv6J)kM$kOJjf!gUdGREJ7(3?&dWU9q-+-QE_{ZQ<)4h}-5}Tfns1&APjOIe zD+{+sfvRwT`(&2FUzw#ZXw+uIWi+EQJ?d2S3;X!K6D|!GIV|7w0K!50csp+4WEZI< z)><`{NY^tFxa6|!=K>PbDCLW2xU=fTe_6)=j^*0ZH#2^-dC_h#I%oHXl zj|v4O7Zm<9+P<(Hy%g)t7cnZ{b+n@+AE1r5Ze`9!a)`-X?5@idU1a$<#m(0yO2=;Y zZE%W0(SZ#ZFlE71oMlIHFNOflVMJIYG26v~e~nyAv8FcGAtMcFE_fs~r4SE}6js2z zgfozFv!ay*HZDcBGzc&wWJRo8a2-E!n#ZyOfE=b$d?$cRtAWVFr-8+wg2ftM90*@w zhE$%Z*S@V00Nw=Pk)Q$72sKN1HYO;uqT=kWftVYo+8rVI9Wjq#!f~Zg6HztUd_q?wm$3ClG8=q>@USPt3m0ACHmhmc>vToa&0ri=7fgxNxprl(y5|T z-;;1cf@DuT+i`k`*DuH?9_bO9d1p9%!zq?m%AdDcW2tyh^T})_=P-d1r=ObI7b29k zU&~%Nr~iqKf_kf@<9zn^il{Yfoi9+t*glb~)zIJhmT+UWmo~#+^B#%bHHZQpc5T zQdfPC-*eBa(2Cfa>d`6-iUv+aFQwRV{U0xVC2VgPsN&?E`u&|X%#1mA z@@wp}nC{s(m%<4B3Qtv-YEwq_@bpbrF7VTw_c;Bjo;LR|oZmDUO9|7TnMfo*gH!x9 zdRy)4^NIs4zaP9j4Xhe_&sY^1si`}I6v_NeiOb5d)b+DASA>;k`L#kUKI7<6e7a0R zbvu)FGxiB<%%C`ix24?owe%T|kKaS7%4N3mh}bS^7g2k0y)MinzUA9_Cl_DkQ9BOq zJz@2?gjohekA^fhzTxkFY1-4l8lw*WdR>`NV}JtNpZ`H7Fq?2Gd^n+#

R^zrLV4LOoq$F463$Fdybi}fz?>%@MZyI;$(OaG-i z^~cD_H{kXo8}l+_3YZ+TOt62f@XK59fn2-@J)L)H4_|_KR4iL#zkz=|QL`XMGQUfQdd8ba_y(hz-r<;k`Ms1cF!F|6OrUO~Isn>aZ7HsB@a|0W z-X4TZps`Fqm|vznfmxr(|zNp{g z_$Pq`VwTsiQkKMhvhnD46-riIaQ8y)k2Wto@H@)eW;eaoBg`NziP0&_)nz--%YoAPQ)uHi61Um(d-XZ-*>D0T-|5BBxN&*##Q88S(>ziM0ElWE%)Y5Z`E5ktIYH z-e3he4OJ6Z3`68~{IyFc+FYUX$bE{%T-w9kE0NYCm>P#yJWax8{BahJa?@lFi*8{I zFL#Q3`vzu0iN5;r#hT+JW%V$FSmO$dh+}Py;CL&!fOjHqnAKYAOsKktlzysrlj&^F z^B!W|%N5zeN6fuy%t^G;Y(|3Rn84GFbJ)ejCkH|ZyL8Lq<4XwMxxD?#My57jjrc#mAC`C*6ao zyIo*mMW55(Xq2O~XCm3m?UkZD%_!3@Qq9``oGHpnS}$uWoT6U&?7Lfi_2n<7BgDVQ zW=@6^mR-XMvM&78P1093*2t@BX-ZXlRrWy~3kOd1Mf}j?q%XDtJaz0Eu(|r%@p5#W znyaEWR-n_@j3tsRP9gQ@HeVCKw{;5s`c@_E<%CL5?(zrvS8ANN^xlcLndM(m&|Zx1 z^3>a^Pg05EPP|3WR>>#pSg!s+%@3~sxpdGeHG)Bl9m?K%ahaYnBdim~XMOMFZJ*%? zv5Ie<{EY|q0OMkv&66){n%Ym(=`foiT3Me<&Q0&UO7!Lp(j|U!=KWY4=3SlNTiv|L z1;RKX@1WTOZehGRAB@nSkrk|RJ>5qNvGGCygc-Rzs~%5QIx$mQ_iKM#+!Z@$L* zcF=mh2ZFr%Wx_LQiwZRCbwwmD(y>i#w7cpCuOB|8Oe><29Js!%DI;}#8h>oy)n~oA z`<{1_YFM?%60$2AUsg- zBQtuOcU^iS*U$O4)nNJ4Mv7T6TiM_vn`HVWeHN0gmDe!vdhW6LCr&+Y7lLs+fn5@9 zWddWqHTYE;=wTln{T+A*sqj9>qu-L4|HlQ`%-A8h|3Wu=vF3)PV!trdKt#AWfuvvZ z?czIThELOQ@r?=exw*J}Xf-r2fMyby!5W-_-LniEDZ1*=9p&3m5eE;)khbsxu)6IT zVznN~a@Z(k^Wo~p5cAlM8AAL(;LS>B6y)Tbg8o#IEio~Xv{v#*yYpy&f1s5KFe$(j z3M^3oNd*$aDofpzx_X#7Ez~cMjqXBBT~QH9DU%c!OZ|%BYkl=hJ+IAD5b&1C_1#$S zd$jm_A8(n#)DVWc$gwb4BY&(!6=3h_*7sH>E77v>kth?dEN6eMYj$(xAvDa7nkUvu zl|%gPH}<@MeM&Z%?|+okZVBq{6@p5OEisxj0Rp({z~43|G(b~2rN%#H=`ni zJ`4Otl_U7E1JwvcG6Qvc((HJ~Y@g`{B%6pax}|f}l+^Obq*Ys`Z~_MT5WC-2Sc3aj zpwJi)r~r9_5P_@AEZ;c`mcz3@%x?iLxxAdpzyLVeP$i!s#Dz&SLxK)U_5~^h5K%_P zqyR>S<6@_<>+wE2n9X*~+)P%c=TeCQmPflKmDmPjP|N*HJw5Od03Shuc`%ZkM3hq? zcDeZOK+)z1S#s4oEDmWNL$Ky{DJ4#BA3R|m0z@xeb?EXr5YuT8123(C~LTelPeY>1pF`)adbA$U;_KQM1 zv^5UA6af5a2fvi|+09{KsiXTM(KQb-h$_+ z|JgZ!d3;xw7T|S2%=Np)R|9r3?pL+=Xc)>@8AqJi&I8~Nc}^;ShvBJ7gUsqQ$X=d) zB#!t*`+pZhBio#}xSy$4+w=NV(Q)fZXRvHYEvR zLqRycxd*i{oU}VTa8L$jv!`F-pdum;ZjN9zGBsaDKUWN)KD2 zVkynFIy?JrkC)@mQIwQ(L^^k}b^XLepYo-`-?;NHHluIwIjnmG%lkvVBYaz0QZNZn z*65wP+3yizsIMQ%edE##=ZsaD*p37B<_iDTrWYT;vl8iU2jeKYwBh7v)jy|bcIr)s z*0*x7UAS9jEF*PJo+9G&hrK_k2lJ!ZN)ZtBxS`WNKle~DnbB~EF>RI1W+~)#(PU>; z)=QSpizjg}Gc1}sY@9})rE8v1c|t=YBXi^pwI%1sr%1}qR>V_IRXVV*n;z|@6p;HfkCQGZ z_tz-f)+(UUi);_;Xu+f{fL(0py5q}`H4>`pvpuO7KaD?bvREZJby&fBR#cI&%dqUp z#quGI$aI>r3r7kuI%6p}pC~1@wLUoSu)>rutenu+>Sgso>UbNOit87G zu5V@~5^V+2ELT~zHM~;L=y7$XFTP13DI;qe64S_N??^0oCQ(;z=`~ipdhnri!j>K$ zU+);DIa8v(FRT30lus@zk8o;k#&;s$&X|)==C|ve;~S|Y<>j{JyT{76hSQab&&&}8 zMxUfA7wL!b{hdfH%uXbJQ+d4e3~w!?kV?EIhMvZ=ibk3{&)d* zY|=rWGDX(ZVkqwNW_%+l%h?x-6y|C^wQ2(ykBlk9wDD;T`?bsY2*Yzrhq7n#af6vS z)iq*Aiu!-2WN+6B1UTzA_n%^)6z3~?d%~y9XWrd9&uge2Z<@eWrxm*H=%b3kYEouV zRhwXnc)rS!rOO6{E$qN;$tSP&O%r=aqt-NJ!df4NGO1Fv4N~5&Z|5~NOUHO;DEZWq z%aJo2a}=8UH^?pSNapm*k2y@r=zR7q4>7zHi$DFCZn}z|JCC?Yjo$Ef{Mda=RCZYc zMp*7nb~shV;tVa5b;Ebt>BfP~iHz*bPk4s(#H&ej1f+ycyT(MOufpG!7&~Q^jx{dD z;TGKAEhyv`@b6SBcUYJ^C~mp_ffQG(uZP#b@@=p_ZaqXuypq;k+?SNarAW%n3c?Ex zFpI@mr}icUG9D0#_W6OgtVr> zwFO`ZRJ=?%`^K@l+1$jC7ncVK1OOj}+7DGa&F8&9Q7Tx->VKs38$=lZtbpD9x9iM=xCl&wzH6nV7WiBsc^O_+;7@4B zn(&XbjA-U%E(q14lo=>}mBD7`SR6#Shq~ zF~5^%-p@l4jWt{)IKII&4XK;}=$ncSVeBXerzGT`geYBUj?F}>JcA&OzkYQcS31KX zmRWq>3e-o}UR;8BVr+;)0zDRTPHnp5sBKygRlrr|)fYI6bnE;M?AhZU5aTx6*r3^3wrFW{k1%Ss6 z-wWV(6hP(F9}J1#kYxbyfCc2^0qM;jP_wGcZSAS)+-iu!dpuj4&5fQ2*TtW`V8Bf; znQk)yv~}?JhWiElI+%DMX;!gcmey0fsrDA}C4%{Y?gTI5v z99$QAF3X(2E3z-{uF-D^Kp-&IoYy~}69P#&j3kEdBPAayb3jAb3wed`oF5cBK#3lv zI82xYce57|((5)ruS0}yV-On16QY*~GOP!8I*BIdKEqw6kS~RM{t|K>;H3|_I{Hrz zBnIR_AmXFY3ZQRCfVnkfl=nN-xhzaWF^;f8c;6u+65fh`PzIC4^v~D8j|RvS{*+&! zABO_%rsr}6jCRW)aY;0<7l0XL84IuW>bZH|yj(EBL4Fsb9<$ZmQE0vlQ=#4wRemZ! zXVEwW#<-&P-P&CB#WFm0Pfsl?S zLf~2nuf_Hai(m@Gs$4#RVYGlaLE4rd@&^XrFmS>^bW)G;kMGchz-LfIDT84ray3M| z9>izKt>Qx4ar06u&?u9^AOU>eVJ}3~6jH%AlK3kq*c)<3k4H>oQXx^^-kz!I+d1$8 z%e(f1{7I`8XTnFAowejYPSEWGg&}}sNCm;r-V16G5S)O+0PN{IH(7&ft!YD?H0Jq8 zf^lF!M=6}>c7VM+Ob;)AT*Ef z$VZ_eNKpxz8Bivu@_>gZ;MK5o!e~_V!%}$%vK0e1=|`}WLULSWa{@9>_%`6IGUmCW zy0hjkL-&GKJPX_0oZ(#h4!{4c(^wdKNsG z1-yOhGeQVPxXgZ}kdHLm5^EA;#;MWZT6D zM{`5Z>^a#5zZ;d_z@|J)>!~RzL4Y^0SVSh6@?L~WQ5BN_wOL!8UFLV9HUC&yqUo_8 zM3;Xq9oWjsfkFIWh&kq@L~Vs*M9Z5~E|Iu)t%~SgUZZb-^0sDO2CLq-s`}e7?d>oO z8#})wNAnBLGQok%DOj?+Ze^RDKd*+VJfi-s(ageN{y@Y(xm>j&Ye=i$uY`JC>EpB_ zoUcu`6;)#-d6(axV;wIp*6#@ya^%ZY-J!W2LBYDz}! zvr%5Q;!?mWPHbKLTrjqzsmd7;$n|Whl3#mA@rcoF{hFu!k4fS_oSfDGj9a+K=but_ z{p=N&R3`jRdq2+T&vd9KQJ}cNjB$I|_+08i!M30tO|mHQ)^O_QxogQBLh(J5F!l)2 zVRMsFyViFQua%|7yl^_T!lUa9MNbXZ@a)zv)BCdqI9n)@L}#o93khgoK|M zSqwN+ZI=P=Klz&QH*RoeheS4xJPb-%JvE{d+27^KcAs>^4r8K|v3w#VR8X}%lUQ7O zO+{9*DZ~$h$ef} zS7Qe0+^}=9Iad#Ay-My6M!cI%OLj<4KJa;Y@RCW6{K~c09GDp|$&z|QHiiWDK*8;o z)}0S2As0ZOLxNjwB$aY9G+stCWJM+9_zrAg>froWQc_|-^BihUsFaXY7EmXsR`U%{ zlLTN;jx%CGqxdT|H37lVhTTb$hqcRB;B?RCriy8$yXwwAaMPKLIR^F&p!cAC0_p7{ zX$1nnNblyQbFfi;1X0t!;sz`W)~8zEL5c(_uL1bB9DZN~X1x!G#Zb!TgDVUm7W-hH z>kXVoe)yMEVL`oyYVkjj2EEUN;v+esRnkQxFo1JZ~i z*O`>p^JF0{=30`=!bnT!v_D^vkdhM49S&H$Gn7qigg9`{0XS$C$tb>&sKIJw0wc<2 z$ZO}Pi4D;D9ULK%q_4`sOi)OO%+YnaE$%AJ^Z;w2JH7uTWDez81CK6hNu%zGB@h^q zOFYQ1?pBzBSNcEN$nM$%a6`@ym4fYG1LXI8`M^a*Bz>?XgqRmOuqc{f))z z#3m;LB|i)DE1@V@htbFy?Bf%gds?Zql)^JT0Zf|FZR6Jgrd243 zdIo}*QM8Bq)t9~pe#JZUNHm6jnz<7!ulHbN(GHyknf*Prc$6zS_p>-P&dLZ9iXjXN z+2wh$YE{it=Rx)QWESFwk+sqs?3eS!SYd00?vzj@g)s0UGej;WRP;^Hd!83?N4 z8yKL157p>MoaVj1CGh9Ax3?60`gsnH&zNgqdC|u4;wD6seV$sU9e4w4GW_A^d-mb` z_wQFPye^^zGi5L!&&|!H#r<-8bl?FL9FuZOBJdM{5t@M4uH#Dm!9!{#fuHdMi&kLN zsBB%suC_E>odA^?{lQD}$PuUjTvkWZBO^&Wt%c%LgFqyS#^i}SdH@rC8QlYZQ3!gV z)Z$6UH&D_mSca0)^811LR+<(UI2{=C(V%J4)!UWo`Wk=s>;j)?sx?A3Qkq^oV2n2r3g=X)iG_e&PZfBY=mjOcU?_)$NjfrzcY`l36MRT4-p_qq)N&w-r4us7pup zY3g_i`3wNZ2<*QC;hMoQLxGP*$WlfnlGeO45S$)jxn81+2v!HrNEPlzz@5P#q5ZZ6n}oG>N0fKd9cfT@gfVBN0#L9 zz)K(RT0r~)sh}l9x{}fhKSd=nc%2HmuiYsxF9%m#RFear4e54a*h*S{8B9%!Z-}Bx zuW?(xw{;6TY$#|Yu&@4*dd}=Vjve?B*BBTjP^bR$orS-P{ON@n&mA@}NK^%_OhQ~- z6ZFvYgYUSY0r>K<1hyy`d02zy*V4uYaCFqHLMSmb^3P%Fns9jN^(BVoI&`Eal@I8^ zud8|91zljiMshPWrjec~djn|-(3=J}TXHxsIw^%agRpT3!Fz;5x89_kDkKyhra5zvuTn&wbz5bzbMdl0*kUXm})I z$x6ccRub-KSlE6*9wl(oEiNfRxEt2cCp<3&*Nmd1yDmWw zpx&+yUYg#Pm+<{4HXFR-guF|f=aHKA%pOVhL`e+Drx}8#le5=w+K~5)7!xAOSoVMI z;Sxepv6m1=ATf~}4P8D;mf6E&lF%WWnx1&eJJklS+D|Y9OkBA2?G(7N{5W=CWty~z z=FZ~!8Emb`Lzg(nI%kc&Yaqv_okDC@uqacrvt8dx4h;(lEe(At zqrUiY_kQ+6;^~$YsT-_9Q)vtKmNR=gw{-D4#g^&E%%_<~h?j5B;4VJqAhcwVYx+r0 zQJXp`Gjz04Sfg){%b+46&*MEev#IoQda{S0SV8~2{U&>uer^?=J-UmZPOGjh_O_8= z7IjoXjGM(1zv+>=2s+dw_l0%7ml^)LsSsB#nDzL}MiII@2bf3G`5%ho|3?UQFV>vo z*~ltf?pVIYp;f8=T`iu$gQJYjfNz$iMk;wkI1>D(H)BQ+X6cZzdnCUG?m#9`&5-D$~z!tc`nBEO^(j=<<_U zwki&5wO5W4hKhoNf^6?a+NQj3*ga=QX_m6Zw?CL;fT=&cqDAA+08=qv#h&KqS zb{e#L=j&~ytk+j(dQP{QCH{DGo~NcCy}9a^S9-K6S{xSCy+ZcY^a>3&KK^?hJa-nh zTJ)EFmJ}YA`*48CE!1^)rca`Xmc zPQj( z#E!Y{jS)<=C-G_Xjcbn=>~i&lQIzo>ZSWTWY6^Z?o^h$V%jBl-aYaaJt+J8;ta6Rv zu@k1s#>Py@nP%9c)SqkwFGv_Wpn?agM#AMlJo(^P4`)l*Ji~TL3xzG?_-@u%RHYpV z>888*^!~fAR~`3Pl$ZCO{6B08NY=WCrysn7>oqDM_#S;APMFMe1!8gP>_jxL69O>H zM=9TbvL_rH)3}eqeUv2$zOS6P9j)bMYyn9u1x?JP-nW{=WV2F^`DHY2e?W6aG! z+wt-xGIL4M{>=V8%oJo}IF1&PtX>5^Ce2sR-D=H{lZ!GB)f+vVDC{#CBHmZEZIx9! zj}sJ9`l>f`{_amHGtduX&+YgSV~+|T0%U>B7oR?92$1+JH2wcvfElYA=fo85b4WoVhOG#@ODw+z ziv}e0kvagerKmT6m<1X{ak*@&-5gI-+85h5CD-J9)RO5^MV*%52?~~mlzaSiFbLJW z?E*9(CqOp^h)#PPA0ZpYxHOHFH;Z#`NGxAP#q}FE=<)pPyxxh4i)*9`!wzJ(TG|kD z&6{DvK0F*F@dDoU(kO~2!g=KC(~$(YfXViE7;{HZVL? z4PeRRO%~8+^?ZG|u(+5Sze}6*=wkjW?M8F4xMrB|63iH0vg_}v>}B5ow^ta{_3NvP zJyKmPr%UV5($S$s)epn^wCOK1Ka82F(Jf|YXAjQ>lH-lokfQKJql8M@VP$G4hd6S8 z0ghNoSHxG{ySMrB3vDFBk}|iXM8>UO>{_`XAJeI?^IE2r_$U|}7;J#O zIZ7eavS575=rOOfxz(m-WL)5G$ z!F%4KNU^H(I(lIZ~JYMN~R03!(V~SkN{}Nk0iK7ka5ab_%jsgy2_3#rx3k^~~ z`Xp0{%^LxvUNg#-CvQ{GxAFP(L5#ATA^8q^bPT-~jN6~n7a!WUb^Fep@h}}s0;bC2 z7lld$mba#+Eazae0_w`GV~<`ZPb~&NLV&fbeG46SNfzpGEX~|SbNpXSv&MPfYPi(|E#?|b))=vG+Xls?X(T?kep zLA}w}2%$8*#8(LA6#A6lpR=*uUxXkQfZg+nDy#O+&SJFRuW9NMXedL2%O!PTfCN%czFp_>D4Z?^~U+e#Hq0r(gm2EbDpWA9H-yAdliFey%& zTViiOwSOPl%GfrO{wH4bk~2|S*rIV4@q0BNPnAX=%Xljz-8AK&nv>(!8JTwG6xwWH zyxXv+=%8IjcZ=VsWh%UP?+}0axwXGzFyDK5XltpJgic_fJQ@5Df4vBP%pNqMPTR9* zz-lkWrv1P8kKN(y5!e6m;|Gum1y|lhgK8|4$QDUS$uC`d$3B+dyp7#O(;i(;JnRaa z?{8l1H(<8TO#Iu5m{iuA0|sD*qPfF1tmHqbh8@_sCL6sG_fQV_Ld5P4Zf@Zt9vFV6 zU_Zp_0XGCM!Gyn3J+L7}dL!l;_ZC1iI|m2%=rN>pGl4NJcbfy9IqAxQiTqAadq7Im za}NLJPZjU8rsJ{Vd9#~sZ1*)SbD<}EEKKb-W!5K=q&LqL19mR_^6QMmuubsLQ9zIrm?$Yp2_2c#%`K50Q zdcK)v^-OaW%O5s9Ycf`-=HnVQ{)HvIIJmR#)7XXAXFInAhxz?b5HbG{;%J}im_8}( zbDS~Kp@4hYZl{}uNW*shvkwyNI1Y-?RV1>EWNG=)eSAA1IJ1BIRO9L7rR|*A4oCF7 zQrhe2b2hgA>6UVF>w_O_3~=KNG&-8~e7vr{6)Q|9Q?r=8y&t2;lOH*jS* zSIwfyk*(v(wCE3|?Pd)HOBr-Uhx*02Fa6jNbNG35V+0RPy`}3e&M3?LSNDZACbqMr z-!FH1wU%5WzwT*Vk>K^X`!B0MY_$_w(OIv!e82K8n?7H%s|~lNmzyWm?wFjBvqhGJ z4|&WhQiZHb^A?Bg<19myJ(5jl%!PB+UhfG!dv3vQ$18CTa~i{~jKV`VgY9S1T9(_} z9v`CAXZ?1>>#2>Kp!GG@(Y$?yCSitUrj^Tk-*ek*S5gWpbl**0*5J8tr!;l#i1%(qhR0wHtdGd$;J9Z}t4pbbZPE?Wo>F1XnrTrM_MFX|4D-L>ww`)lY~V z>OVD^N~_`ifj^?ws%e&k{*KYXziv_Gx|;p_wg~>xW{?qRP*M*H;d??ocu7=3QPfj1 zS-q3J^0EzcKGnh&eU0O-O$=^Y&wNWhcL`?++!Z(e?U~4T|L&!y%kr$&m(~2k|MKk| z>@#}(Yq;=hzSR2NH?Mst(Xsz4*W4YL6J@uVZuceA6_2dvjeAUGcNzv69l81QQ0??? zT~65&j}*s5L&uW|?U~em`4P2EQP(5&f5`VfFnPxRf%-vUSjHyW9gQk_OML>KTMvlu zh^(+u&XV)f8u+YI7GoK+E&WEYjMlR)>f+m61@8*R7L+SAjn#z4k+oPO*LrxB-Hc=%??oN0Qz>73bvG{B}vnB!FhsXBYxzqs zE0d_nm4xmu4(-VTz!{p79fBZpgI10RZdrtMda$hu>wf@@`WaMk7%&jkPL1z`j!?*- zQl352Uv|+^p?JplKbpkS&`}dJOpHOj0oNRR6fBKI7x;dnOeW_7=o&y&BKznAN{m>8uHoEf#N^sS^U_ETASZMAD+ap2ocE<$~o-$ z#4?T?#lrf;BqiB_7cn$9Eh_|f?Mj3PS0BX=V-bR?$JU~nA+rRy2%#l82{4n2NWYme4fydlRNkwH#awm{Usj2+82;MZ8d;w2+)0a_prF z{|Dj28W*C-`zV#I;MIlWh46l#M6Txzy*`uQh!%iA@ZOCpn`=fA;5I|7ck>idaBm&z z-s}()ld7IwV>3mFIIvojDqQ8|GI%42ED2{j@y>s7S+ebupD*SKGN`)&R3=pYnb@|5 z!uSN#hHuZUtsveIs_f@v!b<8kxcd{DA=oU^GRM%vg;m$oC?XXiY5yNF0Gy?V?;Im@ zT9gO0OnmLArv1%e6Ow=$S~l!o5PerCBmP-)Q~0`i5|1G%shjD9TAi4~)z{Y(UHHXqVO)?m zz(NN)s3Mr^78MEZ|24i2IZ3E33HC=QEBMvxHtkbGSRgbW@YdtJ+O3P{J^8wJC%$6a z{4%eDGRpWr%kw5uTpvl(*Z!SMNbtZ@Gc#Ke>Om~7F}TLOhj>1t_&^wX}Zv?~hZFHI9t3 zvl?%}L?bmjMnove9Y>jZRJLDvHEAI5oQHb zZ|h}+bP}ZYDKCaZXbGy98t0Q_d6HZPWPo^8YsYU!cL5I22O#%F_)Y*fcRyM*BaQlq zhcoUh1Ft+LN^QV+n?c|(N&e-pluyjHz1}sXT)y0f7mm@zqHv3nKQSsK3~>9m_kBr2 zN$58KI0&Xn@=1t073_Zp4j+FuvqYj-JnP=IB5!Wt_gRu1f;kjY7o5!*C+UR%cSwJg z-nVrVVk<}l**#BNkX7&~uj3wXRSD*zop{w02U*eC1P8sfw0I>lnuo#0WcQ7zT>SK=>#8`OU?xu@=2Itmo%G)2RgbU(R@{fULow{)>X=vsXU=8v<;g*tw zo7(Ri|4A1f7);!zI+_)Nxz9dOY((KV!RvKD9Au;9D_35UY#Z!AjNQsBL1P1lcdaqK zUiuNu>opD;k$S2yPTMY0rQ>%KqfK0>YOb!1g#`z&{tx&kPdwyneBNtB25rIj!mt5D zJ--pB%f559LuQbM!736s%-gd+y67UBz?9)eHXjZ;^7?{K2Pz@-P0kpnfQ2OFYhYnP zMr&k0gGMtBC`EBGk%0ODq{nH>goeAUtbKC%_{7k!HGx$V^YauuTmk=9S=Wc0WSbgN zWcHw1h?Gg(X`k~vvB2rFgXNxsdyDNl_NMq<-?Dthl3HElu2Q%oTh+2lW?&Vidd{av z8$Ml{_fp00=9)%FZs{{t7&2RP>b869xTo38xiyb=Pzy#4ZjYHWa4Yb*!TqJ)nYrM7 zUkq(j?voFfrOvd>uh*kF$SPhH%Htw=KTKFd!e#u=CMJuUnU&PS!f!TF_-I#r7d<9= zZ9`l2OyZsR@Y;T@KNd`eu7V4G*(IN(Q`L>O>Pu;!Ej4`lfaR9w(JVh2Gs#-_w_2wx zG>Z%i-=~Q!d>6L=Y$C>@(6P@zv%x{bRv^i8l)VCi!+V+(@5&iFp5eP>ayMDc^`MItm%5P^%NDBMclv^3Eqjva zuBE9}$I@9=9PPLtkuH)FRx1i=hQRlVgq(v_0eb|L_N002DXtKCe}A5p-RG#wv%c;) z@taX~qA?yZSCvF=MkG@Y)ms|6@~~Ws;y2CK@-$14(dsZfMucn-wT#) z`h2|)CHkY{9@TK1y7D=+|Gt#Lw;C=DQLRne%T^vl)M_=`*f*QmM+6zBy{x9BtawGK zE>vW0Dnv0o8pr%5n}U<)N$>1T@pI;P8BHTkBX(%Egx$+-eWow9F>59D^K~J5xeBB6 z2L|LnOAT-6(&MJ8QjyUeOJW{LQ|pZA`NP$Gq88E)iZx#^Q?x$yUnMRM zDTc6;--k7?8Fv%MJ9f&R2dBLo{HHWn9nHlz=sEHUvjV5c zzgrXcKq?`EPn>;1n74uODFg(AO!_ey1AQe{=i0~5pI_ZLv>gyF$!>&2jvbJ0Y|D@T zOz+%!<|7NL0$1Q>IPPTVHmmS%mb&|IVnG3ROBJ2^OUu(4f~bqN*|iCco7i>m)8V)! z+)w=WY#b(-yk1AUQDIwmZ~D|0iQoTzbT0%yBB&i|1ZdS!nt@saP@OoJ0PzsswW}-W zmxQDw3ZELV2guc9T0oFlh~dxE?mi0ZujsD}=$BsN&m|yR*@mRP0Y{w+zZ!Tcfuq<(#o|D2fA%mOj@{4=214`Q-8fUzS?~I=mQWFiTwy+lD4zX^%O=H z9=LM(vM!)sL{q{F;oGl?B;M?Ic%TJ<@lh%`)Y>E(>`Cir0dbvFxd`0)j64N33yW*s zg#1n@Cy@VJ)w&IwySoXwPEF@2F^ML@3IL#9G_Ed5M*d@`eqwcQmu9Lk0R<=QZw~x< zbjq4tTS(b4(Hzp#~>-7eoVsDSAX}W(w0|!IW{*dm`;jrZ^Jl^;Uc$ArUj2VGUJL-TAN3kWpCPKGK1%6MP$FY!Ye%@HX^oazBg7ARpr=1iq3`3XvpM zri*95X(!P#NS<8?N_eD#$unNoY8=7@$%C&lSYM17Ya?oeYsjIeGbU-hZkMwx%IlAK z-wOe_qX~eCcw*H*Ma`>kP@M|w-Fq3G*D;Kl=PSPE0Pe>_;1?2#1Q?5++mNs#-qhWp zi~;v4cyaQ?k!4jdpooDu;7_S+Lt^l$BchFWiByP)r#LUaMhe)-{vnZB1le_?yzhXe zH)7%<-mg&EWEdVhwiS>q+Kns_G**8VlX@FLI5!MBd4%Z!4PJ~00fc@ zY-&Kh9YO;qdnggrgR7~>YOsC~xC`?EFpP-6k?bB0yTk9vgHD7^6o`X^cdBq~#?Hee zi3b~om)^ppMJ$rSP@me>j9B5P(QDJlX?#Md);6&~;ybXcCsC@yve6Dqyy%bH)qn_N zEm}!LznS}tro!RN%ly_X(#v4c$hh?I$Hm8ETtx!v&SC?_8vqg}>J|Fs-k?h2c2x8o6%3rw1 z`^b6KJ8&zt9@mxKmM+O!f>H0ckk8jRkRFT=liBU;-kp7Cgz0L)MdO!_aTNE zNOpshQ7G7;QlXl z;oL~^38MBH9mkI-X&bqHH2fT3YYGlBr#!C7f!XU0Jl6XjV*pw})fWBA2<8bcy&Dk| ze6qJGEj{wP%G58kN;*)h#->QW3a@=c{h>&sLb>LJnzrWi~o_%HL&07V~^-4~s+m%lNak zWzNmF1YDs}k80nm?{<9S`)SW-gB?-gVW*ekc&NUNSentk0q);&Z@cS1IfH6T#Q<+= zZsDsjyXa^4eZ|{=N z*O_B^xc4yJyDiGDtsNVA{``-#@%Nm~4@MSfQJSW$vZ+wT?cct*jY3`RwyDcMIh|Yk zVv|EzV;MNLEJgL0lBL%R-`Zjr6Hon7i6ewpwuFJ@x{gTo&f6&x6elQ8UDAKU<{&Pb z6f;BhLfYYszE)zS8`Y^e`aQ3v)9;?kIrrgjDz$1-N`iO{ch|gI_Sv%YiiewDS!%?V zX`SAlye!^+(fq@DCyC#JmzXx|ZPi_rf06WT`9;g=?JQJXKNQy&DW|m$h?_NtJ}T4c z`PiP5borK9%vWQLXr1T{H<(qjgdZP%a-||=G~$5JF%wtgJPi~}L`n>1zm}-(@=R4yMA!x!1#j5A-WUkt0zFwqfM@cnGmMpTa((_Mf<%kX`6av_%`6= z`j5(YVuK6KhU_k9Vx)#;O=tn6CkDW_A45=DpSp8vVc=zsiJz(xQ3Db5Vw6|v1LTHb zU|MM_`tdu#!E2=Y0)$B1&UM(~F(@S^*42E9kt&Dv@tl(Ka`U14 zxA|7jQv5#`fUq>7+K4#O!~m`VAugbQ1{Q#b2iLsg2ts||d;yxhTKvX_R}?e=WbNd08R?M+uB4Xh4Z>!{*5@?pcG&#S;s9^` z{d*;hOdfGFJtnnaAc;&|N3Ohc0BivpwEr?h?S*Nv1L%AHbFVWxI}K6S66-oZ*mwqn z5BL4k3&yJ~N1KUehJjB3Boj?b%gt}_S6TUu`bM3jOb3o|N;^H~DYHWjRkdNpYf~4$ z;Kor^LPCf1*OR|!u;W+bFAk%VomdS#xyuhzzrf{#rn8O!^7mQO45G+f?qk=Arlm;N)H670WU$yNQ77% zS{cG%25RckSCT03cc@y@rWINn5FTw1nh*WHL-iTWEoCq(u2vstKjpv9j{PM&g$}OO~8sHt?#?!CP^nO1^ara{k3Zcu80$k-k2d@RjgU8R!!#d36ml9a{ z%@E&Tzp>l_>H?mCqoZRF=KT&5WR?$PQxo1hkTanAlj0)kT8dvQtCp<6D0iP=!WBzW z?OGLT)Ad(&2*e3O+FOj748Afc2_Wt4LeF9Ij*T!oJcjMKI#x20>B!n+2YGpU9+`N6k>veu8lFLy zc?RtTnTDYyjUMYxVtiCxSx`V>ok@yOpV>dj;Q0bF-61Ktmv!%Az*+H>++6qBU1Vz# znCAt}SM}U}G7oCcI>h*BoB!HF6i8KQJdsbFl2D1~y9(3G@r59=B*DW04+}{-A!-Nm z{$QksR<`xGCo-+k=Jr-O=mP%-kE6B5MviAsIg$sa53*kXBSC0~8=7j>8%B+-cNQA= zteUX$yYY425ad4m+ILY(D27URjL~$>%lyT`L8m{O;%-#bN#{e78GCPp-&r~(BzKla z;GS-4Ub%%`8B5pRLJpUkRHi#a`)O`Ch4CAe?` z&JwyO6Pq*oA*A17OxNtD0Shy;$@2z2PO367I@MJk-{%4pf`Y7q)5}L3Xee9KD!~RV z&QfdL!j$)D7dQkyrU`%cK23ggmYGKQvf8u{@9ojeyRQAcM)%%(qCH@PSxM9nhK*M? zdfpGVHH|p<*BPDn-WNNfdDhh!3C;v_-ZK_WFECiilX}waK(}@7pSeX^mszgJ?N{e! zv<&+8ZFU>m(D%#LaAxP}Qv>Q++D3W8HE*c>u4z;s6G{jRWOK9-vJreUA$%o`MVPWf z$j;NfIrK53yE5C3)O@YlatoWgXIt4Q=u>q$x`cE;YS~#*Tvyjl+*nW+^koO@!^}J$!6LfbA}4RIrJ5Y0>k{(0_tr*uz}9FmJX_bqhEM2Mw^)o!UVX*P2cv=m*FnyOE` zpf9WDjT)@)JC}1K{Q))EoYPBp`OnbgowU}o6gOAxqdp;Z-TjOGb9TMP=O^@5T=Qw{ z<9^y!$Ji`sObX}57jO@?S#Fn;T}Zo7{ClVNJ$W{MZP89oBQbZ@;n&&Vix%$j{qVqv z`oJk#dht|S3iW44gi}J(;xm^UBPi_41H(iwp7$v-Nc_2<#a=m0`0-6|3x3K=vU|99 zr`mZsh|bA(Ek8>aI8QamN`I)5STfG5XY>Hr4u1isGW>uK4O{Ycq+L zy&pWLOqE=gQY9pib7Ady#^LVS(Yvl}0IUeX5qzTFLe2FU&>3RG1QTcnlmQ=LDnsZw zQgL6=kC5rwiQs7;JIV*2}wUgqG*W#x|7wQq;Jt+f%!o zq}rIdsS({ee+E#T{^@^n{fGMzl+b5o?xh<_>04K_$^fgA@+zNCfI`Dzpu$8G+& zPkt#>HLUeWd+|x;*ih`{%=SbRF{5e_Ba>idY%&j zXp(vDr1*ZErt+Md;)p^;hBD&PR+Z=I6m7c>Jb3pKqc3DWCBFabD_(fu5!i@O#=$}t zLFrr}mxT2JsQ7>=`qW*2?QO^H3n(H>ZrqrnCVA?0B+eW6qwo3wP7-yBp=qS$$b*Ld ziH@eDTMq66l7!4F`P_4Xzjs2v?cy?yX%=|H{oh6uM@C;&d!oQ4cunY#-&xz}{U*SF zpr!y@lUGaGrPv!1WE>U<628l0yCq0<^{lPI{1(a3Nq;4b{!to&uNBrj7VzPzSKT2i zE4_VN_x>^8dl9Gc4@#h891vYA?qCsqAp?#V;rfvd7(g9nZ+PhO&%WnZCj>uOoOuxD zVaiU9ou6r$p?8kgfV>F4Va3l!z%p0HLo!QBMB&ljkv95Zio@{F)?x51uCZEeLN>g4 z;^f3*l=c8Fu#`X(15eTb8VDkQhC9S#Q~ggkj}RLKxspn7&)Ji}>3g zIllMLlKHmn`BJ|W1yxkGXqyrrjH;?C3z=|C<^fV_35K^#NPmtu-aQDV25?-qpI1@d z@hnf`wb3~~$??p`W1Cm@egE=-7yliRDyx4@oK)pg?W#8vC;5Na_l8PbE{yNf?>T6=R6UcD7tT8mS=_fvj{ZNJ*hN?My*YWmoSnX?;R54#LdE~x0c z9oxXNDjMEX{au0X($~!rG^}n~+H&;83B4cF`|ro?I8J5pEKej=safYknMRO%>gyrR z9K)ea$xUWX8Cs9t+T_e|Zg@hyAfEl)THOeqQhCSCwT`va_8vdY>!@|_n zW^?J!4CXlAb$sS*cDvK>8irg{dpAp`dv(J#=bLt{<1vOF)>OR`Qwd3Jv;-ibPj*%Kqbxo(o65SkbcCVvPKUnj-hM&L0jnz7!Mu=K9CY;ky zlyZmBy7w8P?+$(Vwa(2`q^L;g!c}j+0L`ya=fz#V8!i}FhpZGB1nEb#gk5;svAeUL z_2*Af!Liob#%;PkZoSvn69(dw`-&1T{$VwZ2M@oV%QUT}ySK7ApfYCp{bA3^mrHQ53a*WEKm zIB53fY}@D}+`cvY*M4qyUiuxH8q74wdnUe(b4m;Hre>WLyeG((dbA}hhu)o&%`e}; z$2DPi)UoDa?)mg<`pP1_X;GsGTsWxy^Rk#{8#rPm4@LDGt;h| zj?@0S5tGph)HKcOB#MttG!3dZpUJ$C)3|}ZeDcB?nG`u-=WL~dM3 zH)N-%*MGx%VLmmVjbb}G`w#oaB1R`Q4_+3;h?Q8xciq*KqT&}21*lxyIqw>5R$G0PPG__~U^t#5Q>CJ9e znf}mVn_igP7nV=glUkuTHCPtbcIC#|q4ifs;!TUDy`#zcpc0U~F#T1Y9m!{YOX93rdYc80(I7IgW zHHR?;iurY_SJT^0Q!=jY4)fwpNz*=XaPlJG{6$BqduNA#-x$AX%!C%W%%a5e#k^|p zqS*RPTlZ0ZR352Xn*aOwlsEU-BuL+5k{2b9Sr0WJ>0yB3wD9 z?omjl@&l*m6=+y8GBRWr!Mx%ccrYMh!2Vnn41l$nYD!e{*tpk(E2g)}p0;`UV$J!@ z+PkVo@2^if!aej;Z8}oYOiS-UDbfG!LHsVg^|2Jg%&E<4Y!{?{sQ0y zFLy>=&LUal)4_)g}o)Ai%xZ-(zj+q1<9MA6u%N=PtQ^HNb)XEBGJ45!{){Y1Eh zdHPi+EtR}o+>`T~Rlf80_3sVLB*^6iznM?E`L{|Zbz^kbmHFM>ib_iNoWGV{{=2XH z@5*OnLT~xD+fn(YXH@{&F&t(v#n32yHW*y3IyPkd<>mK)uE6N6k}+c;|Iy<8KI%0# zHQ^f|9&Ph$c@TxbYL@|XDw=`2nNZX)_9Ij)f{_DW(omfNrAd&&KIxD!>wCd;tEv|#gq#j8t4J|s!WhKLzh$jU;3F{-(A1^VUFf5$ArDXTc_&RixV*g5In_TkpWSb_ zzP8aoCFH6}XyfT>89*YYF~SK zVrTX3K4oL!5*uRIxmw`YMc)E|C1Y=FyhD?k+EJIf2uqZ<@6npN;) ziK{qemLYKQAg%T-#S9lImk@D7|J*q94KCvn0Rmf9=jDo_ia)N^oP=@Vpn+FYQs-uC7i5edsbOC)4w@?(yJ@yKoJAIkXzw+|23su zNdXT2OGQU!)ISzRPEmE7&NG(pbO;t%C)d>9babjIIqYdHdVb=n7M>Tk1A z1DB@{3CRAYIwRxawVuj+je8RkDSC7t&%YB@{(EovG$5mbUD%l; zUN^1IEd&t=8!aB|C0>}hwzuuy%;y;aNr8ga}KQsK+odY%9)^MQ*OKfFPm*HvC!~$`Tka(f*S*=}+E{DwB0|H`aKVe{e;zVJ=7b2E| z7Zs`9L975h%lqL1EfFYUe}ogaEQELPVHjm|ZeKf{J4aDPQ|?)Mdu6ROK3RdIPj1(a zjgRgCp-WM>J#^?PNZ(I!ilzM=P~+!cl$%YMMi8f*5nL{zQ_@u6+dKzvSm-){JROkZ znY&0gUY3;-iu*A{&pi^7;3;-rNj zzS2H@wBQZ5<7?fM&HxH;xP5My8*in{5Zq&7n}zrjopg6wC}Ox6Am+-@+H#{@92+-i z8n*m9LDa0LG=+JU86FNyx4_i+groWsM0AdrQA8BYU{JqVuk?=l&wS8O;uM{_KY>od zM%xqoCf0v|U0Z;UX`|^ZOvp|V>bt3FN_u(}B33o%K6XB|A?V1U62t33{PZi$K# z&MAF5fN=+sgctgEG#g&C;@>5Pn6~DpK|9hYY(}jV~)H=eK5crI@4z#{&_IYn%m*aW% z=*3ei#a&DMe+3@ySGUuqt{ps2ZD}oiOR3%RbJ*+wBNw5v09xmkrhfH1vic8RN6u2p zQY{1<^|WM3vz;uD>^xxRlXS30nIYni)>PQ`ORmG7G#*zk4Qbunqc_{5KP!FEoncXl z?a@muWjA}B;48c|6))A?<)j2f?@JU;9{S_!*SWFOn#oX>)|pdZeM4X6YInktg5q4` zRNG_C;Xid;JD>6hEo@EV3H6SYSg&4}5Cy{zm9%c%3L^n& zO)=b=`{Ss;YS6qEc9&SZ61>M&j3r*T^=_E)@)oo8I|NgF3ptNxZ0y*{61-Qv?dnsP zEe#e|R8AV*9IJh*hgm=S6MQx4wD6M>29lw=Msb61$+8+(- z{`|DxZ>=|J-jMUN+FebZJ5`B`)Q4ujEJ{eq4*s|n&Yh}v!#_1FaKqG?{*#Oot~~|T z;aej%3+fxx{^^hTx#Q0Isq`@44fU$uQ+{mP?-w@tRHgYuVu~lL-m#d0vcb10{8hSj zE}t2scn#lvVmd;z!MRiT*iG9@$2!U{-qLeC{#D57;X2KWFZbG0AJ*Dsx)5|{slV<3 z!@(o0)`zbrpUOzeyZLYPmZO2&%rC54{Cb7Ig#N}MmC2(jp*JG9rF=iWxc^*YTD)SU zQtN9-Vw%dud?AikpM!liKNr5l<^PyPn{n&L`&n6BTS_<{9?`5gPi*Nvapx%KZ+1&H$hmUu3yPbp%g z@iOyZ(zHvylFdO~-Ra)*N*CyjHOnbtlz-lGWXlKE_NNw5HpPNC(USd0&lcB;JH*`(}M@qQ1`<^<2rUgXfR1H zg8uA2DA-_iU1?{Z?P~KlJ{+QX)52m)Pj7l^YV+9Y&%GGCfqCML=_UM6n6yoa!UOY& z6+Wxz6EIAUl5JQ%C6q!31vU=ZyzQk9I8d6O5FxGrD-Q#$E6hp3AX(@4W$lp!((&l1 z;nNAgah;{8vmDQLSgC?@y^(1OrqsBgVM(FaVf1zN3%NN^ppQs~!7&5)B3vUec0>sd zPn8=mRm2wnoMMKf2o5GMlfo8-kRSKg=@yvqd6=ByM|K_yySQoFOZZ|&S!&^+EH8gG zI(q5$jK6Bv;j_-58kq=jU#+)VgQM; zAgF5KRc-edTsfWnZut?5BxJ5u5~4v1o@q|m*l$zzc=gp&N2+7Ojja}(dPC~-9=*$r zb*?!&h&$#<+)ygQFLrx&Y4m~Ty+|h*lK|2~4w?z>BXpJ!Ub3Xu^csSdiGGlyvP?`T zfh)X(5U|w){1HP`Wh>uNPotg9utn@A* zR^3&N`M14d=2~i|{-Tt!y6eIMjeofB$PG7X&_YrF~-28_qF05ZZ6 z7YqW7u4oxthV_bxlZ*1#`;p`iWEjcO(b3*=nu_cjq`~kaiW;a2Mk-H4(p?H;`|MW_ zSQ^X$@Yu3=_ZSxaKNmpi1Qq(^gJQyYZ~s?JWt+!wybbUsiaOXU5WNpvtPrq>>-7aG z6Gbk(>)(=4qI@}LfVpR3SLOuK2MRy+H;T+3^!nt$;p^z=K;+{O{Hq1b;vn}ITh#YB zG2`7l`;nNIk#Y;KHVUKYXj*vg_`F^6s=Fg;=2!3_w@3`iVYXPUWMTcck$4!c-riq; z{KE$4)KkX2K1e#)s8Ry`N)Oz7p}+udOt()?eZCfaTJR9F6nq8A$P%Tz&;asU%FD@{ zXKtu(#=zc=c6k4wtMNrE^Aw_kd7w|&A# zvv1x^LT#na{`yuN;sx?wwFDVk6aFixidx@9m6eruOghf*{2BU&$>#N6KNS7P?i_Xz zVl4m#8 zNeL$jHxQqQ|9n5{Sv`N{(M-?s*}5GPp5?>vmt8lw${VXiyn;a-i)aADR8vr(t>k+r z3?iT>L}DFcsGeCXFu%*)QqbMq{iXWF?rV=9AIrOngq&YGqm;o{_t=RraByvEvLhC`_Fg@uWwl{Z>e+d zXs2iEWJ>Fw+M4H-X&kfD;>qcf_;2gQgF+PFGbYR8SBYCQ1&Uax`&n%}_>c7~_lDH< zJIgqwOgBvnGtFdP(XY}Ny>nA9XcO7N{0gA9~4kX~*}fUy&Xa+aHZq=Dk%pCp9si{jo)P^uo1^ zSK5`I=sRAF{MHcVbL+tFgN<@AVK?5t2>g)$xpCPd37YJv{c&YtSxjl)c4yqxS@*?``($SHj@BiO>{|yy3-p9}tz>yMm;8wkRw;IP)sarnp+xI<8q0his;R@L7h z*^mzij}|eow`DOfntfdxk0~AL^}-gfZ7USpAo*yK6?m^8T2>9Nd&R|C3a+pZLP;dL zbLU2*KDd^xlsU~N+$QXojl1uk2mBCg_gwPa?-yZ3MJ29r5(kwv-@Sb4r<8A)D{s`) z0){I*qf3o0kfYG~Wu89HMWQL|-v9TO&&z)^jjK1?S68l9qUs^DT#%!QtQigsq9!U! z>?jC2bN0kT(|iezQQOlZ}!O`Huy?dHos=)`Dk>B>sd&^VF1K+@o7rS>9@H`GZ5i9 zz|-@{xs@dn=!YKL5Bk?@t~@B%a4QM0KnqDOzBMs5M4U4*cS3;+wlx9JVe{O^A-*;* z!#iiM9Qee~(!$a*u5%a(H*!4MnNEC6>&chK^#9~^{|yizzhlfxv^LOSE8zb8ce!vu z#F1Ony>`Vp{tDc!5kzG5!;Z$J)a2^!pX|1!loF3f9ZTMM|)qB54^vz{p;NnU2_)gDyvOt)!a(dAinp2o*c9Ey9r25@Z3q zcK>7Q0lZ+SQwg5FXXuD*f#~@YTwwjdn{J1U!JLW&z=YWo;0?C6wx$)WH+CeFbWo~L z(I&uJ^fK@rQwRuA4r8)>0=HBDYTcG7AXcDWgTbw;dc5-!L<@LY64keOd3bu{82jX& zQkTQ?EjCD96kB*s^MUulH7OLLm-1AhAr0E2>0@1ZZYy!s4p|U2}o)jxl736Rrs2fTss;Qtr$;14xrH zvnCYk&vkC)dF69(^hC!8T38!cpTy4*?^xT+k{<*EKVTu0jB&p2&&wGZg|0k!9qiGt zL9<>3!YuL4-Ex_>bIPCi0@>Kw674mi?c&jq3J2z?WF!I=LF2uruTgj1lZFSP&LzGM zlzQFxFyt>|NoDRn`dX60{?2LJTMzDR)&?~fWjO3jb5%hPB&O8HA zEV*8>UAsQrd*k)=%a7hyuQUKW;xm0LqV89P$aNu7@_< zHHOOAvsx{T|o#dY!Lx9J%osEF1Ka;KH0IDO@*h==%Hb z#ok5q!7vWEiSij1omdm#_OLpwzLo@&q*&mC#B~CA@Jn=66s6OE<{9_v3~-B#vxNDpRcjSksA`>2}`z!rV_?j$F@#mMuiLr>i|Dao*l z85V^n3fhNf_-m!s+wl|&+l@B#r=nxO(>#1F zGULHBTjKiR-?lrqWu~s4Pg3r>I#&LBg!{gkvRrn$*5NY>&cTwJ^CLIHcCDCG_Hl4l zOL5-Y)$3!(drM65uhVzS_B%$JqwLPTZ)&~ElP>#Zy2>Pu6m>Wx21ds}qHz_h-XSTF z-*PK4I=F~^!ts7l;@PVQMw5EO4~b7Qp1-Et9d_!0riKgm(WDTO`S;SjnpzQ+1&7}F z_Z=*rcU+zSbD!d!dlOBxDtB+npyT)4e(vq&3oY_9m!cCzGru2wSw1;OJtZua$lh9F zu%u*jCitgIlG}|96JCcs(_7f0!?+)sUSz-D5l2JoAL=AD&}XtbaZG(ntZPz{x9t9? z=axeMD4AD(xyBagsOLByi~C&^Kcu~Z?<-s66U`de(`R24^0o#Wr)>0{@GXp)cb zf3erYO1$h~MYHVoJEc9re+L>Z+`k)gT{dWTrWLJeDw&ZR6niLrf&FXoY*$~x@k9P6 zKHSma6-#m6Emm4}ry_ptSIq+7-9NMRE9Z~Mf6?dO==ZDjl^$ z>`l_Fw#Au&rVZvl>s75@nzv^+i5xleDmElIE+u)&q_t#|WkOf2yvQ^4Sa$USAt!yx z-KT_~^wjZuUz6V?Y$uT4EOGmz&zlRQUN!s^`DYw5pBm|}^YAm<*%ymR48|Ns5LHiV zFyfm+Sbs24MiBu(0qdkrq6px+R_t5)etycXmB$<P5r%q9cLWfWL(=b(v7qLE9hNquYByYP#;-X=2;*?w0I{@bX8H??St5Vb(}gnj|_1 zMll$v!pHAB_GCwBgNQv8$smEI0`-6fz#}+Mz?sS>_@6gsURA}%i@O#doKqNYBBM&A zBzAc79=SntwGKp6E+tQqXGPhg%A|oJLun`*h&2td|G-&yIKA6qqluE~=isJCWhSUAFy7pet?$*t1&FiIej|?5s`L z9MunQ!~2lwBZ=@ySG$V=t{m=p{dYsZp866i zAx}1D*`VHo^U^CBM$!ZkDpkif;Gd|;q*!*6nJZc)ah!!$312>ZfEr}mLVoa)II$Q+ z!VEev(OtWe5R#aAwG?tD!By_~249$!aG^mZtphR&e!!hdEhdZZPC6{XgZ_=$&Trs` z!AHrdHa>_k^x(EiS_+bwo}H~lZA-jkN++s7B_-z^Q4r_RijhPnp#(a}lh8>pP}jox zOAiu^J9#9E!2?d;c=!e9dei@nog%(05;2;{wSg>iOpl4~&?N%R}|d89;vkuEY>^aDLg;aF3VN(Swhh%3!BNY_f^HI{-U!MM&5M ztN#3A0II@rgfBtyS|{?mAf)-%eZBX~$G<=tfu4r+c7R={Ce~Ib0HneJ7daRu`?mf= zHuk1bX$Zmbw9zI1##g!6Q}*&%;75wsG zV~JS+{*JSYi#@t&O-uiVI9|f#ATQmpVX-THwH|$x`}1B!O04^4}2i~c#1&|6iJ6?^!Dq0Q*1e{K% zq1Wzw|318E6kbsvQ0w$JJ_ck?+E|3fu128bBqqUpK@0Y{s#s`HxB_fv6V1IU@EngZ z2*{ivAK`mD;H9ht2J>xY@4|h?FQ8_|G{%rxzdDA7l7f6X@zQ7Qd7QAF+{AdiH_iAV zKxzP?A0oSUu_NpP?r=We6Z9nT?A&g5F7RxetX36{o__)3VN7nxchA`|Qi#?nz0EokY z8IHnCAQ_xMzzI}a*WhZ5rUCxG@U*>B(cz!U2TB0RQvhj%kO;tw5eA0=*#>n3Xi}s0 z^>&xSu9j#Z$r=RnuzI3|d+;C!S01w!Ts~RGFl=g**z%I_vh6Lz(T*G;V%4J1b;`lW z&ck-+fA+$V+BH0^{1($Nw8h=@SCx3srn}eN#*Zxrcf|3plQpu?Rcdmo-eEM^QII=w>7earaY3aW1NJ;ThyAtY3>vsanO3&Y`^n4h zd-W(N{iNP4G5whE#GUgQ)N!-%64Fd(O}6g3ENXPlPHDhzX3J7gB!}4xg$J)S!&7hQ zSJEdQ)BI~FFiYoh#pAaVo51ac)B-uLq8NulvD+J}6lwQA(Ur_fH)3L2_73IOrmMPJ za!on^gsy4u6aV3aXdT|v5xmz-umefq9J!zk;s zzBu98+4K4WAh|t^ALJgaOAycgaEyY#Yun}Kvhdyc$;Xv^uih3A458iczmB?!wJdy~ z-MMN{Y^KH+Bj&j6+Y{U#C}mgP_A0vCF>Ashd{!*`6h}KZ)3ZHqJce@xmPH+wy@kG7 z?hnp1J;q2&opzOy?bk>i5O(*?mn1{JZM5ds{uQ7q^)c*TXrkb;K$qf#>hZi&fj$rC zMBbjd8S#MLZT~femOJ{Tz1&k&N7Xs5eUEH>PJgtw*YWPq>s@a`vMt&C8{%s8TEshMtEk?qqWi4`{dX}3Ua5A5>@wXGr8^y33~r)4#tl&*|STFN+>Y;ID6Jtv~CP=4~`)8dFVA z8mvNx6)MupEki8Mtp0o`S6^bX`$PS~jQ;J39hB=6sOY~}zn_;Ax#P_i$z-7$vd#MC zu5TJEX3XBPr)SxB>MPZS9B-e_`efe`SW(OvW@~s02%A`4ShR&Yk*; zjh3Wypd?~3=^#}y2AiZ5Fgtows(m9Sg(Q3gFO0!}7WRP%X(H7h))5h>@Eg~!(~Fxb z@`&M2oaG6$a*7w3ib=plmnaE@)2*h#o@sB9fNn8_3k!0*Y0pmBRQ_`(&I1PLsUv91D z>A)}^8wN?$LF5WCYDB5aVlod(4>8^$J{XvYE-4iPl*U>%l%U6FkJ5!0V*;!I5(jo1 z=+!Z8OR(KRloVOai(dRS#HN@yY!0?UFlq>2RQA7q{;VBvkAn87zWMq#6s#w)y*sF( zhaeeBLGU|Umh>2o8n?P&OZIT%w31MtzsP0oN*_|KsBXy$dy;-f4kArpPmlvG~J`CyOYATYx!XR42h7FiC*{ zI|*hacGVc71D9Ke48R{`%La}Opf5nUXQ67;rL8R_?aUq6A*Eu^Vpx)Su8Iy}un)9i zlO>zerwiS80_^=k7J2NeB46(Rm!b-40gfK{fBarO+Cn&ddM&-mo8XppsK1mffC9jP zLj?!Q8m&zUzde+9@BYpAwSmQ9Rbr=Mk%B3<6>9f=vY(?8$)<@cMqCJD<*`M{Gw0R_ z5Kk=o)1OYnlW@lb))aXA6C)0LE>fza-GUaLL`z-W^UV$$5E%~4Y_U4t2TGZ0u4?N? zAp_7XnwMdd2@0nt{@P#uwJCnmkO742?CK(6vdF;#cIW|-`jv@K^1@*LbCqxgZ-kDbd1=}~WcY|5| z_26qt`U+sg$30$eg2$YuS>VrS|0%Aj1C$UC`HR!BH%4F*z!H>D+94 zLnIm{qj_s(NbAtlBpJDr4HL%i4MPLC)a%ygIV~0Dxp1Rk1zgJ<-%tk#3yDCz^t(&^ z`kpt>{4W<^)U>wNzjj|iByN41z*&;GK$_3niXILYedDlkA(r7#H9&cF%hIH8JJ7_W zg8P2K?|yg=xGOFYC(-RK#6A|ctahudbvK$^b^>~%p5nMetjn<^m>~v%lv$VyL1H1! zXaP^CX|#1Daw3#wLp}G{*cfS^VBvcH^`^s7YNeyjl?}SlvlZ2TYG5T|*po^`4@BvI z@z3#6z4eq7HbIL{gz#_?i&9XsKiXuYJ70lp<7BP-cxM52OZ1U-LhDeopiUe>^wUuK z+G0A2O+U8rlPJz1mggg@3X)wpm+WwjSiL;J?!-JM3H2^@=LQA_>T`EhZE;VKVF`v} z_B486e z9g5paoLTW5p50NWCIXq{Bc<_*)k!mt>SX2B{ zat0busx(vA%DwcvKlS$RkiUbv9P11xd{Rd+U2y`n*I}Id@=MaX0~TMLtYZeRbW>v% z2&{NSJ``_H+044#_zs#)DAk)98~ce~2K?Ut6LY6-@O2QB1VSky2~9X7tY=N(Ob@>( zV$11(_YbWusi?S)TZsVE7-DTpJCXh6CcJD2nhoE#h(G!QEd6-4p#haOc^r69K1##@ zUk@uK$r0GMee(&-P5%9vCWF1;)#-Gc-9cU{V7<1m?LvEv_e}A+K@+Ay%1eBI;#zzs zPqeqie}bZ_FR=ep;KJ8+jg@mXZ*#Eh*wjDFSaABg`a0e%v8&VRr?)iy%Bk47n((C5 zJ-7NL5)-C(WKIOp%m}c)=cT$Lm7X|DZI`O$Bx1{-`GvyhtDCtUeVq5(X5|Zq8h!R# zLusXLCj6}+;Zw%38HK7)dc(Nz;{xgJ6ru)pUU!^7yor9=v)hcpF_~@K{$s^^n?_8b zRuzk5+afv5p&-l~D@IlIT3fQ@%}>i_g|Ehsmb$WQ6FJP99ly^c#U`E1>Q3PbN%N2k zkx^XV@?=};lk+?`4(xJGbUqM$c+`3KdJE^K~MHorl zkTvbpHF*5?lfbey@M_HH{h|Dn=-V%5&n#}r< zQudC~jP6i9mlAK(5eXsF+w7+)U6pQ60WRrD~sX;k3xyuN#MA|J_Zi=;smTzj#hx zxnTJ!-@sJJsN}L}vc)#K!tseM`9tGMg_{^JSV}x;;gR6mqI&Y0bC-F|d)cA9jNw<0 zD8I&aG~Cu>W1;e{biI1pUQi`f*r_!t`a)xgSKw06;Ra1wrpFQwAD!m?EO{WnU-_M& z++DXfF8)%n&kbMR&9z!Kzv5c)_9y+|oZ0A)zFRTDnO7!9q{ibT6BxE@r@DlGxtIOt zW5;u&oa)Bx3rVf~K26oF{%7Ax;xsv=>;~@JvSu7Cd!3&ttL!cm=k@#P07^Kc1-fQ{qG5WJB`;CRKO=q_+? zqqVg)Jj~ZY3W49F`hhKazvolxyR{G!XyK{Eu5jlLH)-Esw@Q-Q{|oYk-}m~~pQ^V% zOACVPl>>!2_K`#+9_L)R=*cHx$dw1bn04qcQaJu`(lsp4HJzDIA#5xUn`^`F0$C!E0@58kNKZ{kW?93-!$B&40j>fmbH~~ZdRchaFA?~*AfW_kA!HZ- z3QXa0j8QuokdhXEei;g5OrnW#2;gR&OKq<4_FS6tWEvQ{^X$^8@Y{tkDI|l#@;TAazzVKB^kpM0(oZ`S^0ERo%5HFsf4^*0; z_yNJi9O*u=RUsZm*vt?j0;XySQp!>6nYL`%3V$5*^N_SOTF;MeP!WPqQ2e9uu{j^gB>^-S zV01#0kg~^(fL+*|VKY>Zi$PP>oMVvG*wmzCpZ@jbWjmQDdoD8G$3>YvlccH!r)e@t z=NCNynGIVsak&R4L>C^}9nO}iy9ur1^Ml8@z4!0m09=slUZ^M&^+gE(4n#q`IP%r2 zR}{R83JhdwjtV+4Jze@;7~~VA^cO62&AtYNiewhzlsN>GRM0aYA7qx4m_2HeVT2q5 zzw%7H^zdBScEU&j81j|*!-?+uxakq2BOAOL0G1({*W;!YPo+Y;1#57SCgNdvjyt{H zq0~*AT9s$xF4)SE{sf%O#Sg_nuOa;idX*7dGzyrKKqV+ualD}NbhM$;2NWa&ll9P- zc8Mhws*mBbVh@^bXqqGkgGYZYg&Nj`4*tvb#P1rLIcdnyX;CKJfk`$TD+!JklO==c z!AekYr%iF86sVx2GRND}S(=_XvkiWg4uFa%qwUgjNoRwr)Q?(rz`t*q4-;l_rb%5P zR#ccEg0mNNz<~(dprA5^%wR{XVEU zB5%8aFUM58@9AlUd`SynIC&K$kkGjfc|HXojtUSiGqU}?JCMZ=3kT4Km)_=s8Udj! z47CgQ*?#7x=MP#C+<9}qv&XqbM_~FSv+AlYhSe$hk0@F+u%f@ceJV{dF%0Pg zgnxy5iNL@(ad__?Y=Rh@v{E$OyO@F_Ar-8jPStF*$w{1QuFD!2u0X@1Iiu1bZZ|d^{_0_>Gme; z3)U?F*@HHdWMaQ{yVAPOApsNhz7yu#XO$YX-0FXQyY*8qIJ7xB;@x!$vvsx~TGQm# ze~jIG(!%se;DmtcdlBvJy^Kv%KHFLUhNL>249)3pyQHI}RFPp4Z&&(#@QhO!mFB6T zho{Sr=dX*_e`(ET3d8RrdDlXm@*X=iKPQC5Y;L!Jf zc6|5qzYo42PU&?^NK?BQnk;l=xn+;6t^cw3#%Ejh?6@Ox*!0LB_GG6yv#8hGk4jv@ z-wz60o^(8;_2=^*Q~CHYk52#LZ%wIlbI+pY%xI)8-Oo1;tU2Q_J2f?I-go}Z&&3y= zyVT2Wq@&zASn-d_%DYpAI%H3m7&x(j?&dG{EKKCEl;Ev&9w>l5^~ z=@de*JKai1=$O#?w7Fq?+dJo!uj6VihhFeLFh98am($~7QMTwLh773~A&IK7iBB{8 zsjP?3+4rIO7rcP5uGJgF9Ul+R>UiJk)wQMKHMVrf_m;t~3krpMDa4i|hijHTxX|qD zx0(x*;9%q47zj=r(~%nh7WTUKeZBo@ZX`gh&ww}o@i#*V2`zV<z=U}utGc&{fxG8M}jhn%dQmVIym!}`=KQb=v zT6V(@4U!#J{F=FYg@zX|R{C_PUSGOf^x0&GLE_Sf@?sq!R!BG3w`y8)M8w9%2EAnd zdRD`4?}l%y0chkppMCJYi}kJaO|7Syx0hGctkYhHv8grXI5541hTl#ZnRDQgTYkuE z-TezkHsm;J;ZyZ(=#S2Yyo&-XEX%n-&?grkf6i)EINIEGf)2R8 z$JoF?X`rWc9@H;^vI4 zpv9$6K6o5I7_RTR)HtZXVXLdOjdLWG6Rj_VD5!s*AK{q5Y!8zi|H86NJg z{JP`Kz2ezwJ>E9p2$$v+msiUA_O1rcy;Er{{M)*gKL6n2_}68_a`}`-F?|+u4mNG! z^+RkrS9Oz&yySOv+)_Mo@WG1QJz0cT)XUy;l-tE^e4Onji>x6BCpE`$7BlT5jxHIQ zzpATxsz*P6>*LR#h*)pwEWX2&tu1*=mDPnIv|#o#(r3)Mg<{2j>jWmE9m-_+U4OlAG{b(M<)^O7?gu6R*JqqLBipPqskh^?z@g z?i#Zwi=u`MnQu6iA}5O!u_TuMu;2sE6Mey7I=lnh)6ys?f(?wizQZqdyKj6@GA-m^ z%ve9gThEVmU1kY|(J?X81Q#IAa2?Qii3lE?w}l0z9x4$5^3#7HW?if@lc-Qh(oU1M zjvwbb*%^*jT<#im5bw=qIKesDRO@XmcbgIY(f8y7CWP&aneoj*Mrp*uprE`t$V^L3 zet0(MP&k762mSlCmLEQ#!h8m2r^8uZCXNQb-9;NYCM8z0DiF?++aB;TXqQh+=A6a7 z_q4)N6^1=F=1UPJs~ z?b84Ir$&i#?eYKnSC{ZRivKxP|8u_9sJp#O&n@4o)1_wQYw&RA5&t>;e_v`!^TqzK zNMow@1uEK+ar1re>D9&j_Wk!$ORw=o;JJtuRq+)=7S@4NCz)~gE^HOKYmzO&Kq7*G z0@M#p2~dnSL?tJJSpqp7z7Jxo*sr_}<@)|Vw;KK(+{Ku@ON7pvW$)X1=a0->)=O(| zt1SfjPS<79+o-OO5seg5e9?+G+M4rq?K!39NZr39B0-5qY2B5s70lKB_IBQqt1em8 z+_5FcG59Ha@2ZyMTutpnSoXt|?zM@I$0Jq=m%93b|FucZ_UT>spg0^Vwc^pKc3Ihv z_JgT=zY_oM_@1ytkX%&VD$_P!01`ioy_AdDjZ&ew~ zAfqQm`!mf1(%Xb{>*bBHPu!WCN?n;!VgCAtVGLe>?3VPj$VcQL@_4;ohq zjwP#Mn@lr1v?2i${(b0^;BIC>5~p&KU#hHeU9)ibs&Lpga5n)@n;hj-;OV9^RnM+b zHvWIFca6*X>JPpdT%ilw!x}m2_B=3Ho4Hr3o_Xi%S>Y_&>j?rzA#~>*!!jm|{Uv=@ zgf3Cvp1=C+bW-S{OaW@eEzFf4>RMzlgKjb+!Lu z2&05IWyKrmq3)%cfToBe63~&isEj z3UBy#;jHR&zt^H%)$$ioE$%1y4r+wH`Z8+U-MxF|ji8p&k!;tZw;So2@?ku~C z=RUlAHUm`6Jf^$Pv1~EBys*LSq?5q|C5IG?3`_cC?)DbTYj*j^U#Rx~5Eb&ul3HKJ z`Tl6g(QWO`mZ^^-%Nw-vmK%PxKm9Y?U96k&fo>)|=G6XFKF-R6p)aGieAOz|2&WMm z?EbpHB{9Tz-NtL{b2Ty!ZIiGMxtZ|fE}u=xvkHF!1TCAE`LSW`e$Ktwa7 z=jhey>ojKHJ-W-C%^uvlr_uGO?o}F48s)oUcTIu^wOQJxVfiMQ^0u?ReKPv~S|-XG z!EigI`@+l+#10tu6DNN#aUi^$Eaj?Je~Si}C<*qs@}NIMX1CmJM{saZ+alfCq`>1z zadEt*)N7zG>w;TV+m;(o`J_dT_yT-kY$^Ns`?s>aH)W%v9QnSFX;%5U5B*~-~p-(?R9HOl)+YZMIC8Aca=Mw*}L_zCreG`LCGW9S~owdwK#+!k1 zvT~38?(y#2sw&H)lpN@`%lF+n_UYGv*Ar1H+o1TXGV6m5b!<9a z6vcWwkAJdJkuLL6?@-zDGd|~&le(#^Ns|e`lz)`hY3FfC{8;_>B2=K}jnJay<(`&} zS%Vp#d=hsu#x~eDzoWB{HDeK+w6Gg@{~4g7HP^N|Ybqn=Y0&NZ4}#6_7&G_Wr(@5d z{i9~k*W#Nve=;H9Qo&G2)`r5OD3-gY2WJhA7LTe`&+&iZEpp?Df7fXcu+zx7Anxmr zjz#@Dt$xdmZ+z&>4zS%ROj-ztYILTu)My)0dbXB2#S|u<f;>ac>Id7gF9yT_iFa5b~}s*IcT8AblddY@0FLDQ&v=4uLW0 zqtc>HRYQpZ+l>c!8#w>OmAiWF`|KnrHLTmNCn3Swx6jt6JM*G%T=mHI;By>(qN*&K z*Pir<$ZiW)vX~zo2}udBzUk>D-8|ssUoE(%Z67A9oxQY8wpI7~tId&ddYx9w$AevY zXOt@NG8LnhM=Wn-Xc%op%8#jbo|(?Ek+7?Lo)J=c%wjoAVv1fcr(cVP<@ODc^ybdW z?dr}UMkfWmVhle&kSYBocKZ6nCedlTiAnw99qrFWB71)<@)=M>3TqUEHcRV9waI9N z#M#a#9C2Ig+n~o<@oVY@>n$jJ9d0#|X}*{>v1 zh=eIl1Wx#)+j%uG)Z0$A4oKBgD8-**0(o2Z#CkZ&TnbQYL!tEqbX*_AC2c`8phxkB zw&_n*{(BJQGKlV1fyAENZ1I2^X}VpD+1%RtUx-s~GO-LGvL`}z#Dne*eMI}mj~6tl zOViT$$O8sG7)g252C)5WjbQd#h{AIVn%;VeT}tkPxA(xaopD1Y!E37)`VTI#$Fop= zef(>DoMvV61e_&bd-s}NFZlgzC+N-O+kN*`BCc=ItKa^sZ^^$FoYg+fWYaCBNw?45 zH)^WRf4?Z6zJG?_@Zf!pWaR)YW7!9d8UyK;RMfBb#B5}9(Qqu&m~}F5P2$KAE}dV_ zzn=U=B;bz8IqTGoyNu_Yo_h_ZFeV2o8NKGzETE^6mY(_|YG`C<^1i^h{+EoLqfP5F z2YxQt^uf*JHtqQsUB^2i3Pzc3_aY*- zl{D6ix`h0ev2)RVbsNWIhkHkVAsigL@+0!MGh}3-Gd!ie+iBn%=C6p!l z{3~s-(?0qo)TeVP1NX%`ZC?FfF2L(rp)$_q&9~kk=hKp;Efxrj-=dY3OfQu066GBG zU2DtVtz$1Y=(9CyuCxdiN?*-5ME@eS{u{Fp^{~kGm^%`$%zg$+jTCR7DATeCG7fK! zuuqnoEaUsb7fnZ>OijL%bj^_w9j7f5nH>WX*0-PLNL~tC`eC+IH`?@QpD|s@l1)Ep}NydC)U5G)Q(0i-6}e?pNk`>Fl(#lIhoKRgMB8oZPr?0~7qr3uMvZ%^+3 zM!#dn4#3D$pr$6njh#CGdwdeMJ%XGPixBjzlPd9WI)```dmzzng6hjy*AKp?z~(DV zOGhF!Lohobc7CM@J;We@c|mhOu%}L?*frTl6Fliq4J>56#%en4+Kk zBtx=m38#7;>;am7ctW-5*({k47e^!{{;*w}wFUn{-Bnhfj@YmO{3?dM2o^ESh z%kvGjK25dG_K9JB7eQ6T(%X`Z#C&Hu^g86n*@gYm>6Vt3NDR~yFqF}NWn^cE5$z(& zb*lf}d@nDymlwZck=RYC$@6)iNbPyi-r#rb^N&NbB<)$2aesBvZ{=hg3VkFECU*IDZrzZvZ%icVuk>X+Tgh$F*KCtC{n&nmcH|@|ZVNwJ zWLGuOStYNR`BqHfors4kS721eib_&XxkBf3ZN#Lc@|2R85anVekDpDey|*HlH}|HI z%d&d+>rM7FTC-%i9o)#uE5VQu_9>a==vB(9_J7YAvsS6AtQ1*XI|~*JlN&jBGs8|9 z^zF*$`nqa%ynUOct&CEg_CuD&S0QFcJSU&M-zP%1wRtC=_o|jczr}I(4Vs zWEX3m$Xe0x${z+ge+}WtM#5B;!{QN?=jCrA1!AgTf zX*uJ2Lv^p~{4sIDaSy#TiVp>{W@`42P5%rO)4Na}!N!(s%*=Y%$u7<^+IUx~!rtmj zUIoslhOTVmV@Zgb9EiS|OXYV()8P>%f4Uy>P#O|zr|i?mzU51Muxq|{WhOcmka;eA&``~kSkWY zMND&(qpfuNa%8Hy0xdmtnJJt49;352=;dE8cR{QH}W^=Ns$HB^7+f3E_JE@xjd<=vvK^yXLaV7xe153&AIB*hMDRH zm6Y1rE^P@H)=sJaOrsrXQSy4nxz=|hh)b@bCP=S&!-0T99=tp=MlMu~RP!w6vKei; zG{baE4-KcTY_m(v$ARe^g7wxB7Hv>(;J+IoKNP$YcMO zGlh+_Vx__Wg<*KcBr3poTd!W`uQr4z6Gp7_OOFmp0_fw-~_J?G43i zL19dC;_;n&{}~U#MNR_G>hE}wiV|f*e5DCX|l-rW#JG_O6n=S5fpiv+M zF-#i6V4Y4<1cN5S_}f2yvO4`W^IWI*U6`f6KPNT(ym4gO>%xT_U!Yb1^oy=nAw}ol z!B+#5tjqE@TLyf;v7SH0F?|Q`jx`vKgO5s_Jy+z1w#VlfA~f&Wa_C5^?n-Kcd!lvW z#lXOS879U4Y2-`xExxdRy6XE&N0e~yk)Wk4J5t{CtsV6Ttu&Z6#b-)wjq1mSQM-@wc+pOc*X@2FuD za&eEm@R41I;tZtKjk&iryqQ*cHQ;{F7Rx#$7zA#5yt!rbd{$WZvpkpFpcvTxI=+bU zHU7J(g1ZS`4v}?SmpyLfCyw|!mVfS4E`mnsM;C$O-m2EgsF{VALpHQLN5K-UjqK_r z%V@p-MS9XL84R(in@u9Zxs9NMFILzOJF zo+VcQ@Kp{@3^AnFytbMYaio}CX}_qvuGT%;Q+wzj04u2)3TQ`-#)r!gSbtkYS)y`E*3>w~>yi7RJTd&B+`7-rZ<(GiNT3>u7{Pz%$+7 zrc~2$`Ws#)c?)%U$xy4dZ7zBIjH=F)X*2DN-}OSbH+1O_269#8ia#sv%@|`FqB^g& z&eCR!pjrd9X`+8iUdcz{r6ar#r|SQ8?6j7-vs?3o{Jun+6CI1vE-x79V+5Yp>q;>E zxOt1~EyEXgnlqJQ7I}6GF6jHacDmFkU1h%@$QY3vww#u;YCGh-o8b)?)6@$Zz2@1! z=CkFZ!>1Tq#b|CvrsW<CIY~-7!pk281sji>wQR~7+4A!`Lxk$i&mT6Y?QTzc{dLE^ zr^1>J)zMxe9~ExSu%&!bTE8^OT#~w3B(qC?<}^*QY3H2 zf|n=;64uj2S2UfNI42pk$Y)d5@G__`dwAkR_?vs$;@mlDZ~0lSsnpVBAJREvX~xdb zNm1FGTT-W^O#7CK{=poTS-ZBw!s~(UM!zq+v`Ko5|ft zc_u&X6uWzDdwdgDL3-mI$=4ZyGc~&+GCNmw6NX_*Gdv1+pU-$9@3c`$?m+*ft%K%jbG>z z9UW6Qi+xCW_v@CBj;xGB5*ZAs8#&KBA8k7@yU%)td(%VRmj-@Mc=?W&$Jpdw+Ij7W z!5`kOE-gbF1}^d-`4Bu|!X$|}$DI1SZ|P&@`tawHZ$jN%T~C6ZOay0Vx{D~37_b88 z-@U*MLIdhqXtQDp3$bhypJ~A9TNTb`!14`TCL$PBONWl&n?2#740(jJ_@wEfu0}x& zmyB{mm|0+a;AoVNDLv-*5OH>W{HO~X6ws0A1i5yfwrsN64I&GC*lgtjes~ffNJA}M zy0^>Q8tpNdIfU2)>o3ozJotnYn4n;ld}G}!Id}*8z zIhY<`jB)}hQ9xwZ+C~UZ7j&gpWe+S;Vg_tMBT~pPMk`Muf|bq>6S6P@nTkfsFyr{< zJHktb9x$hQdOQxSlR?FI>!x(fIf$51ktYvAS)9z|*zDIzd!VTU^9n##MrGwqBQXd3uVpM`_#rHoW84y(JF%LfoF|K-9fYf!8 zusn!4(#yx^NTe&Yt6De`y&L}|V$Qe&Gep%w>~}D%VdS~xpElyrvZGQ4Vmb_$UIXA; zzvGf(+<`B9H&%W7^2Jb3uYs03x%~5X(2_vt0lhh_XmsprWjwh9a61)^HvC{6dbSe; zq|~6LU!nyaME z{3qzsfJ=UY?M`4|Alj;Y=-P0Wp(xE4Ik|bc0OlQ_Y++OP!TT(>o>HrW13NZ8IHcpl z^iG{RC3(+x>If_qhwaZ-nn@5t<|LDRW_BaM@F1ONcgbUDWeNc@SkDSLR#XZAXW+m` zRW!QM%D-sT;mBra42WbLW_c)v++d3M$Q14lg*%qPvvS`e+s zm!ci>f9)O(O(f1X{7R0n{ceDe8Hv}nY;o9}i8N;;y9S;BJe4bT>MK;Jhlos$6Ghj7 zo9d$fjdL&nOKP0bQ|J8Gl)DIP_yyQCXZ)$zK;*DNb6MSVV^w^&_e?il)7CfteRF7(?ix*W zMXzXuifLb#%D@5RX0`3Unrrl~r;7FdSYIjfFE5ytN+?%fzd!X)d{V_*;mS9<3OtH$ zA4r70I;L>^^b-Bx6`O6EGiwV<+f?ep%2p4&KU%}uIcuvOG2G{7W`8C}{qWtK$V%qF zZhIw|UhUC0q=-xB6!Uv_*`a^yo*mI1WtTe&LN2{JbaQ2^f?;{OmyJAK51)GSt)rqc zm+yc7Rz~Y1+h}7j*eSE==7Qjd9uot(#LW3pm8`LL!==lEW9zTHPx5^LYxX~pu@A-E zJNa(!R_IJ}w0UfHXN89~oca|-WfzqoOU&A?b!`)E%{2cgZ?BJ-w^rtVcg@;~Nq=U? zE*4%VAx8HvDoZwP*$Q)!wk`BILy?u8mkVyy?dH(w_hs8cd8)YH^Uvi>zF+uab;dgQ zY{Z0)-NeO4w+eK*&by_A3I%o*)r|0R|27J>5Q};zE350GXwCQQI?eGj*Pa^ZsP}!4 z8yMmLJm6G;ol>9P%Uu1X3BR>-xjUUzBzrM&y_ zo5{b4MSA5qhkvjp1cepQWF}ckNylb&-kwkz`NRn@_!7c!?i7<=bO0tP-$ zr*Go^v!Q&k@@#UrrgQ3L-^9ySe(gbni<)LXLM4`1BF2tx9b{25nE$%7_Sc0b*GIns zmE$oFe8rl4msRM(V~tg7+PyzmZ|7tml%o5<~;&Ojj zQ861OzNp$p6bZ+@I&`)ldmIZ6XG(Q-bzrhG9V4Ot^oP(67-HQmwEV+220B8cAe%|{qUNF$sz3p9|-JGMyd3t(Raf3cU| z)|fF88UgYO?Xh|gN_fi;5<5ZCr_~o;GMAj&YE5Q>MB8y3vDsuK?zw4YDn8f|Cig@J zgGCd&8jhTJkP^WM=+P(SqI}a1D_4#C@-;cKWh76&@G)E?#}1}32{snZJ~m}J4}C`? zB)`P(y{s$|>m8CX2>W{h>KuMTZHDw1=Z~m;i-TtGW6?a`{o30aZUD{rs#Z_j1yMP` zfF~yz?in@xSP(}beI5WBCw|vtn&4~T&tBJj!=WYd4>2g(<{U(lyb6!FD`*$beMJdE zB=jbG$K^4CmV{@;bwjwHHxu*C0p?&UL+!IiHWRlksMkm)WP?Y^xfEJHpG&~wi6R^> z5op8!+5gSv{}%--_Q}+`yE0POJUcTp`Oqy8{;)4$TTjfAHKQoj{}~cwi@s)<`|0UUL4oIG!;gKa^1-UB1dVghr*%9pvK4n)}x|p!Z|ba$J}oD$yiy-!k6x9T}%T9UPYk712nm( z$DInL(hyAPvQZ487&1F}Fk-|P%%~C+L#wE+ZlVSP1B~!fQuFl*q=`mV?)dzyCYDLS}`3o==>FxREOMA8=BTTc^g|f>Xe{ zg%Yk2cg-z@b4Lvvi6vLdP3aM!q0|!erPwdY4Mw>7FbhCij2r&!*T+|hX&!`8_=81w zdo)2<8TT{LQ=X&BsCsp__&gXc#6T5r2tEot2I67W7BsgDgxsqxcK}~_zNF!FN9EJ+ zYts~#TV0(-ysVDzq+N)lA=MG?{rzXwtvrSF5_Q*(J-uG+{lUQ(Pznbu{w@)!HOy8TOr}=KrNqVXo?mn3_<6iE|Inq8>q4l|f z&1!vay_+y+PY!o089#Pie7^c>%8a}9#HC9~*MA;3n0$@WwZ?u?ActnaKsmu6sB(DE z7mklg5`r{32Pbnb222faQWvpUa@tVYEn5ES)p-%tabeejeXdnHtEwquqnl0guP9jV z9q^clV7QRKznguFHr)f>KbP2ty)G%)&`2gIA&)O(!;9F+sy>aQ%15gN{vHWyJx`-@ zEUO^&(+{zumLV#N<9untTUwu2%TO;KO8Li}yZn0N#eN3;4Y8S0oe4fJ*Jx^Kay2vF z1imD@|NDi|#L$eDPMOl6fqey(50C4}hNo9%*YwIUC3LJ88uQPTaAHw5$Zp#r@jZ&W z@!TVcv67}2rw7I^cIX8xL?5!>x2j5erhJ3ul8tDb_HfAdpF_J_Ww|>O6ql`EREOAz zchWqx5tP@Y{lL66zkA!Q9o$bheXNc2%{aJ?&iGp2ZvlT1IYuF-j$Mpw8#D_(Fz83y zWM*k^Qt1|+zCdxm;bn`yt=aZD-{8LAyBRwFoBwTcdnVD^Y0;aBiVx$=3>DjbQnV_{FZc`=UkeuQjL$km2jExp zc^KQ=$dhjwS@kq`=N+fE7x$$w_k>2ir#X5(^r%?WC*_2D%ES7bZYNdKvHxZc&Rq(< z$+p>9i>BiM<5F6y@RF7iJ*Q^2m!|mQtAr233l6tO{DVY>BMnm7HKRkT+Ui{A+XtnF z6?v=7Y1sYvS_O@5)02zpw>`0=n^E5MI=t+&;3Hj{hll!|YJbg|orqDdem(w~;#^dy z((yWDgHE@j;ku{8#hkf}t1rfW=unr=t%?7|#IWQX;ZqSJvvMuO=ABmbwp%B5&%mkZ zfOE~Z2nlw({Xib7bad`bMz9SS^gIIDCC_py(e!dq(8lJJw+wghtc7O?BQ(f&;UeT=s7eArS-aAK3HqQEVYP zU;rkQJLpyKS;j!9D+3=S;)%2KX5Hg)3zMrPASPl1r^6jFa0*|5D-DH^9Q>x<#PRM& zLkJTe^LuSq?1NO=h>1hiF>|A`gL<0~m2d&C ze+f_#ipr8e!-dsfjYO_n41iI80U8Xs!)Fnv$tf&jFhRrlx zDDpDRBqAJ?f8`_rIRbgX;oQn=B91BO4$V1o zALYfIk%Mjt^F4eTJY^a@QYk2-oQ@=4zb*^2c~mwaSraLJR<{UaogUT0qBB63IQTmbGS%Do^ z4xbL14iBBc|02M_KLUiPS*Z2(5`!ClG2hp0Vv+1hY%~g+yr-g0)>wlwoF4lirFEwPE#B>cWTjS!_A+o|-(s+MUchkvRf)@DKs7<)Z33`t3F zw0*?l{m*-5Uk7zZ7t@T`rss9pZb$gUf0=vO-)cN%*y;_{-OE!J816(>v}ATeFt-rT3Z>!w9f z@RKx>6r=PuvYof)y!qx7W`g%n4YVVlZQ=Lo(wv`BzKEmnVS4R&qtnuMJF=lfH4smN z3CnQHDx5zjId$zb#osZk^rOwsM=5fOms%>yNno-{3rg+eg8MWjKhvvp%XAQNEs)b2 zbc?W)`!^k0NqP*egT406V_G+(fE{!_tGR!*09;4Mb|#H}scjf7!~vpJleblBbw0H+ z5MipS@%PD6r_Jp-F9?WSI8!VX((e*~%DtLmeuUcrl3*><$V-j2jYD`nR$ckSzIOcG zU7k5!m-a~RS6e5a@Zt)W$0n08E_-sHIT zs?q#uSv2>8 zwtkSYap0S>LOGuAg!cGh;Tm-uEOHgUdVL+Es~kSCCCL))rM>x$nfz@WdWkVJb>p<} z_OkWTy@b6WG=pv#%R%Nw8Xta5ST$sx;u7Rf?jI8OZMiVu!La`4UCa%!&_@67{J5$^ zII4qz-+5LpwkS0TYz)qPnUmc(FDO?to&S;@HD}KS&jB?AalRs$YFN`_YH1U-bmI$qs%W==MC8S+|D;pupKptB>s2-3wORKfqY+ksWG_@AtRxsIL8-@!8ZfR&8 zG71YfKYkd11tz#vryS=AkUtFAeMfQb$JY^9Hn@JeVVdUzBQtsEjWQ?N6H!64IG-NO z34xx;$e)7raR8GZ1XJPk<@q!+IsGqY`x~N!fNtht9OokRuBpIq>7aH`1dB8ejA+T^ zj8+G;knSHQhM@9BM2xWR6a^p~0(edY#2R8b2C8x-$Uv;mud_c!Y5GqPGXdPyCcIq& zjmOI3uxWfwf#rTOlyU@cw(oEQ3+?}Wgs|BEfuliNRMc?7YFWR34Z4DlyArN>*(qAz^vV1hhCk`7vg#K45i_byC z1F>rf$Z!Xk*@%AwlrHL8D^J!P8YFc>cW6EcPYwj7uN8IoO(CL zFT%K~<=uwaifKm%B?;Vugjk9@_qw?VNhNwvFnPJgnf+*NKKYScWxl#SxR+d^ekaT;`mQ_jh#et`|qXH zpOpT=k~}8kQsxZUCAwJTIuBbk*pc3%tILWH?yg*1bSIxo>S2wm7MV5=#${XrEj?+QyS) ze|+sl+Ez7M{B;f!lT>eX)~fp3ld3QO4$?+{l-?9e=9MwJ#lbNemr|f$G2_JLfY0Rz;arLi)OQKR$BDrV&Zo1 zwjwcMo5Z@mw6s^R;>J57E{=gb@;zF-Kt|FG?X>gk*H8OeTd!+plo8M|t!q9dO?##O zq<^iWDP`vuST44a zblal#!H-hwND|v_UoF&M@8Uvhc}Eqnuc98;Val@>g`^JS89$PtcaJ-0 zb;E58ADSJk7ksc5QMbhr>pMZp|J|MWfqAr#|E;~)#jw6x+oT@Rk;uV?OnN~K^FsF@ zX$er95?v3*fOZfBjZTU6{C((3+JN%}phs5dj23}<#2Bv>*|&`F#Nii^*(MypYP@p= zFv$fICrpIL%&x(x5I9P}8y8~76+XIOjPzzu*u#|RJ2H8LXi4-Y*l6m0eM1gzFh;m@ zkUE{)6Nem60s^}MF8l8mjr1-5Lz*e=H%%j48c-` zh@GKUK%Nbthv8^6{q*gX`0fqx&?2t^EVPVrzeCJ(c5Q6}9L{kFV$=1+4iOjm5mG6_ zbr-J`TZa+ERE9u-5K1vo&NCi`Q$q+=HMN`hJT)}b3RAaZB#;9p1&=E&^t(pgR$hh0 zrVmsK3WWX=oE{nJfI)H}glVE8EG~p60VN3B@z$$ZadcS=VeV(oSA9+mA$i^xi^7JZ z2T@vrr!5XPmx$~K!G8k51N5RpsmbUKqm{^7+c6h>iKs%(;sM}RfcVi;_6m#1ZUDuh zx+h!yE)Rt1Km3BCz5TwmQ!Bi7FddZUjc3Ott1#aMjPFk0l28$7Q}uE(#5*NJ#|ZBtA^3-p+?>sj?c$lv3_U9{UWYAY70h@a-e#{bhX%D5MhYM>LyUwIZl_@0 zvd)Y~$_|8z4*mxkDWZ=+p`cCS> z8w(ExM(C6sEFTV2F3*++5B&<){r><#cL4Yo?w5z|pkD()JhIOKJ#B8ZsMvp3m^3p* zajKyxHaLT5WpWvFbKFbJh-FGYTl}`^0>;)3SI~H`><^J>_Z*jj+~(w`$hH46c3T;ZGVd$9Ki# zGas&I7T?r{Vq0m2{!a$MBwm$*LWldy&dPCYX&hH2dF&9znC_lCE0-zthoNQpg}aw2 zDMJ>Gk2ZTvcvr=I!X-(3)HH+G>C$PKQ$38m{b>sGw{7+9uex7(D|PLK{!dJdIjl^U zVM%6M3~4l$OeJ}5r~spt>d&M54#RBwx{_jfaoKB@j~y3K;}ds;Qf!u{4o`R8>epgi zHF#yxq>VDHC>;WJi&u(cI|+) z(TXZMoO{{y@5hO*>I4buVm}YL-!5cCXKP2dYnM*VyP~fx`BPr_^3xqrWf`j6@UzD_ zpdOmMqt1po+jUbh%gWXK(^IZbsMh4GKcYsTqCZdc8(s@{D+_pzubNjt#40GH&_R87 zVY*@huS^&dfNazS>0uf;szU)vBc`_{t{`foNuXot*YcsbCKZtA7zAKlir+^aKcygj7{D0_iEC2}0=Y zAY6x}$rk{tccU@Dk>b0_5U+Id?duFyZa4bD?c|Lfn8QIo@RYLyuy5y$=&&`>V7~_G zif~X~ipr^DZJbSmWP?a_8wJL=Yiq5&cCj!W1dEHZz+@YOx`gfsVFQ6G=Ow~hhb$kviE!JGQO^19M*7vFXaZm|IHbd zERDLeogTz@gY;_8LxucY#8pAG1X35sGb{r@-;~R)GJdP1Vi)Z1&Swl6qh=xX=rG?l zGun-Y|DP=#>;&am)V1tA!dmq#>)Ts=uK*iQh)!y`8APqH#6)cOaP29B$OhW&892S4 zgLKK?{?9DNVr^}0^-6j$^v}qg7G|C#T3Y9m$5Sb-W6BZ~u`g(|7GTu745W63l)is3 z1Z||M`1vil5(ZXd-1t!A6FQ1(8Py}mK5np9s|#M zi$Sz!aidQ5*sxLHWIn#~8Du9lkfp!S^-t=@{kMN{+&WG;YhzUIxAV_R~qXOV@7NnaD?sG2Fnr zGMf6tQU%(~4bRIQjINapA2@XOq|=KkGr2^El3d#pe@4?dA^VMiWL29a+C_180t)h|-=TdQktq?2nl7vGvqp*0!Y^^b*$?FKE$gThr#8nm?rUZ&lrObre6_%lFU`ZQ67|cqnCBe3!(`bGl_cNvyos* zoOO^z<$yzuh4u2GmiQpajQrfLbWThmm!k9Qh#&rCAC>v(jA=BjY<`H}DBgMbk|8bn z{=Toe^t{%WYyP0?^DkPmTW%gf-i znmO^;<*jm3^;z2b?UH^<n6OFh3JYY^v!1u6AyX&m_e!i;{4G~_4W2iS*$U<|M*AMU#r9K%Ek*(U+e ztOGDmT?mMNX3OPJDb&AqNLw+maJ`|j6C$6Erw$uN9`JWsQ%;{Y-PNXMm8^o_70dN( z6PqcFEKRL5&cq35@nAB-0xO{L@AB8kcQJ%gC=zB{NdOsv;h29qbAyo`u+%OAi?gPq z%ans7BVTyKf^ti}w|(QYslE}SYM z6PmLT_w(xjV-Vd)y-Y)zBp50oGE)?g1Z(YUQDgsErNyKirA=%p8=$El5HZEC)N5Bj5vf!JYH~oX!KLKM(JNVBhUDRMV?J zpAh1w(rRmOX=(ZE+mrtTg@Au-U+h%Ip_C|ITzM8qBGDdGx^`h38hY8=VtRnt?ps?s zy^8wgshIQIN}1&pn?2FU2eJHSc@X1V-d!b?zcYEkdnK?8KTqO%^T$-f&GOf^V#@F8 z1DSQCR<5>%j^L_lTJw_3;;?e!-r*C-H4d?#wBLW?+uc{Kg20OmXAJCWzBKz>JagIp#CVx;*Gy0Q@KqEJ z@Fq-Rc~lrg^<_BXc^xivuIWB{^Po)GE#6l$T$D@JigFpzEF@z@_*v_5+|C~v%FbBq|m>E zMhmSJ#e|-hmN5U%c~LNK^t-5@kC|2D@1nE`O`fr8n3Yo;dt*3zD=wHOO*|PQ;yV?z zu<@b`C-eSsH?VCrqSM*)$-1_jomgV_x+i5QqyCNRDinf^+ksK*dZ@H_iEkG0%!%Mi zKhpX#$)Duj9pRRSL+-t&E5zSfUMntV7LHPi`y@tVhxzIrXAdo@&9i=ekK+nC>jXK| z^q^Q9WvPfl1^cZsr_!S9>W}au6olVs((*pM<5K50xyVS3mr#Jp>mpC=lyjzm;TqyK zHrIdp1~vb?kYG^~`YcL)RzqOTZT@Nk#%e7&S?&J!HW8nCjD12i0+USp-O5npPKE_k z2)|7J@K`7u-ngjWuJC)xlwROcrdKlk6={PMCRW|;N;;S&*Sww|eyiG4v7Tn)obd#lpYqNjRew9+^&yZEv= zWT;(4{u5tmCmh1ts4lyUlWt!L6N#D?84gt@1XS3fg^ zMwQ?15>CVXyxq7YjPV@KBx_24V9Td5cYCm6Y(4IyKaoh;GqEznIp z20k}3g+?R;u&jsSGA}5W!QF!7i~jlZ2O^Gs#a-F`FZa=vgi=Fs%Iy^VMreT4Td5H@ znfW#T`qF@&4&>*6FYmk)wha-^pP*DlT+$AWr}}e$kK7y_-=l7Tt4Lpe@~Q%{Jj-hU zvS&-2vq1gI2y0j9+(Ebn_{cLDPVO|D+MPzaTg)PjRA@PV{yzG^`V5kh>AY)%pmF(aC z-_bU;TN#5=qMz|M?_mbLI3*gU^E1}LzmnI+BvYs=rBPyQy>)OfZ7W-+Fm;_%*l6BN z97TEh@&A}?w^)8!grYkA>?M7~JbvDfFyoM+;&G}cmda6+j}_gbG1tkoQnNGzaC{gQ z63xSldqODrbOz;8a@54cgT@y6->Ld-1ffx2AIlkRH_%Jey;m5YEO9BG){=D^_Lz5b z+Wn@0du6bQ{BpFAc1c5p6IjG`6)`xK>>;{o*TgD_>nuu=^r~jxG-}4>K)9-sbT8s z>zJb?wK{acdj-Ye9+(3}tEL&`L6+lW+Iki-T}E`20m*NOg%ms1idXAI!sKV)k9NKp z2_Ji7-+#Z_<^ei$L4M~AEJa_MOF!Q+Gom+Bf9!;3s=B=E9bT;%gv*3nGxXNx=)a%C zOzSy0KeLeeY?!vpP)R%Y+$OP~jQo9(%*c?Fb27`3mY$5UXRIbp%51mr;&gd&ZIaZK z!70hd^u{kqD7n#@QC8jqr5n=$Ra+?QR0(s!!3&s1dkVDS;ZCje_2ny5MeZWXm3cJp zf;{bL8VBeqe~a3MB{y41Dzj?3iv>kR?XiD<%H0-q>IHKSQ)`MDOf zAr-TL(A?owCl}AvsnLgr=Zw}2Ury-Hx+4m?I#GAaSjqTx$tr*Qe9o~z!^L#X^s4S) zto&s&;_XhPgQX(RY*Au*?8J4zSzQ#=ipKN|FFbp6WV6A+KI)wu*Vx@{iq}$m6a*RH z3VyKG%^SY>Q<|A5nQzZ0I$7lJxazrbIo1nJHsnuu^2AgM0~(&dnC zp(d;SWAmM*Z74p^-sC>ys*mn--L!2IAGNW|ZT zk!@6$$v+BMFpeTAp`ccBW!3+eK!Ya!kupBqxobrFKu4-{Y?<1+YFb87_>?>ag1BYL zcpoYYUw>KRtSDvAm6KWJz7ay-*hObk7s|(=UHb=n^$(-MaEuWddg5*7jnb#PPf<&0 z1|G5I_T8T}G|!e<_7)n+sWgC_?MB4MTFRhxa{)>1hV@C*dMO&#$k+1~ zTT=5t)_dI|*fUb~I+*opUwK~Y$*%b2UPVHkmB+OJcUm73lE>zRzKmBoR&|a+Q`7e$ zyoNc5g!Cf^AywE*c?lf~W)4k$mBE)*$0a){S+u;5vgbm$-}%cNypEF2Ues2Qry)}G zl`PN<|Ewp)J>!N6G7RbV*G4|Ho!?7V9GPX@Tym5gpM?8b(jG+ zDU&qMLF_7?2uGsJ>sJk!Z=}`U>n3h15xkZz}?Nhp8mBA{Y}C z<_p`aAU!hx_1G>pW943?lFnTcOXoi-7NL(Pm-|v5n($aqDU%?GGe?xkE6O?*#r)x5 zfI3>V2Dbq^Lt0e>lRnsJD99>);|#|{4tC@(7_iW>DVkfvv`J7g@~LNRhswQ^Gn!sc zy%c$RaJofYjk7Ax%NFbt7DyuWf!uCJfDM3_M(QJzd0AnI8bt~WST|8_U z?IcUm=F)MTlcb%sBF^p)xvMH>*SQXEG4VV;b*8Zk@zc(jvs3)-o5~))H}v`7Pz4DyRlQQ$Q21v;(EhU7H06*mMSEqZ{|nE5Dt( z!oaBvvI@j76{Ubw7Lf7>5Dp|fTShz5aj6NBZouXnvO-R~qCnV=+sX~!O^|MX7l;XM zH?BKb%STk4u)X;ksOo~&$_+?pI74Le#+{t=_4$pPDKJ_fc@hp>F(aQ zuy5`TL2Ihw$Mgt114eL5;AudRc^85&cyCCk6*A+2{u0!oG+7HDpNFEtYX|yt9_1(i zdNx3Y3Bu%n&=@2qsHX7nwZH#}K~Z0rY_X0aB0fN_jFhtlCU~@n62b)*L|Kr-D5nmL zIa;4B{lUu##3KOqkVH_DB1{rQ7KuD5R4tsTP**GTHShh#xRAd}T>@sKGp7uDMCD{}Nhy&|x-ub=$Tv~ipgxWe31P^!pJ(3S`R8h#Igb1^Badv)4E4-ZqS@*My)B+H zkiU_-3TMxa@_yKMrcb8Si7Q`351jZjqX+0$FW)nu9tj zUtlptS8yx+533jDU-$!u>HL#wY(FvFn+}LNlE+~{t&Zb3IBS|LVXSk~%)_CXtLS@E z(p69vaQm%c$S|V*8z&l)rwvqTN!XKGJk>EA!GM(>8ya>d>xLSCZZI5D5L7)?)K6SbRLFmm$A zWH040D}HVIJ#WY_;+pl|tp6KJTiBJfUp1&Jo$5Cu;%oG>G;U9aT;aVHM6V}f`kZX7 z;YRjqB4hK2u!g(UbWV+zS$MJZG_*Y^ z=i}-&#wbNetG+pqVgHl0av&mJI4$O;w(UR)b_>ee1R1w{EhRU*T%DiK0_WFt4 zyVA^8eT~DE-Q=$3DFzN5Xv8Fsh-2$qz^TT`d3a3`4`>Op_+O~_t6wm7ueioYmFj3X zj<<*_q-(Rv_HoNMYM>N#N=lo4y1#bI`rv}md(xORIX6mJVC{z-cIi4^0EXe8b+!V7 zw4JD{y$+Okmhq-6zMn27JPnFG?l6}p2vU*oL`R=RXmLs5NP!3rzT~-&vUhS-U4uaD zume$XXh)vEjY~Sgj&^2iHh%D;49fLRhPbCKh;0rwKBYaJ^GE@jre@knQz$Wm|jwY``SJ zH3xgL-IPmFBq`yRfCa{cZzej`C}pdqjYW8rZJLXcU8E&%v3 zpo17Kdqo5t&g&t@gU!hv=uxTxL~^{Q)7TBlPn08=DR)#h>~nxopc-^)u&{bhc%%dy zp{5y|*9K}g(%~5e6da>vLJU3pW?aHaNWe zdN^xCj7#Lfh=63F-MJsKeJxb%zo6gvu13EZJR#zBJfa`W=m=DCK3hUO&%`k9(Qam9p$ z;Yh*Y)nURK8)OrGD7myAmhPwH$eb9t2q&D^3*f~;e$Sx-nCC<8{+I!I&Oz3^is)Zi z&;xdWa`ZP7Dl(rz9DIVxNj%xo?D?WPM!S;tDL&x>orXT z&#Q?WAMQpW0RJI>L+~{?#SxqjiWM03AR}SEQ3lVq#i&3|Ri*3gqht5_cW|S2tD5?> z8fh09D3?5z)zT6N0gWv%1O6fzV*u4edOYNjT?%hi#)ZV0?wpVvXq7weXirU$NgvrB zs}0yEA?|r-31P2V4SfqDc|AfX1}Mk{Dt=HBM)i&kP+y=Kg#|5a6Yqki=!ft+JK|G? zh>fF-R}ww&RJWK~uAo<7Y zIy%@*xJ-#nAfVB$v0DwFQutTnk4lYCxcyySnyFCsgBRPYM;ERgsKaU$*2-QhU4FNE z?@iE{o5X(0WGMU0`++_EIkLWl$J;?oV~g*yfEYtW<{I!JGoqWSJKc2jtUd*#_?w1ti6<3{_W zuI*3^$8}&$ur&^3E`B!XP5PXH{S0NS0&!N{@ zNR2!-^@;oqY_ZF(c|o-6BxesW91FW7zw6mD8Wl$5=P63#gza8YnD3*A=GjHLiaZ$~c+vlzB_@C6uNG-h%H~q{En4R- zc6s0L`F=OmPg8Nb0jyFdfoKEI?bCPdSbyM>CRli2m`9^feG!XHj@$Y8-dup@_vO^* zi8x2A3Abx+o;lIJs?6N1!GNDAWgB)KkJ#SfxsD5??2cJwfi(HT6!+`nB0jOhWn5)7 zg~Xe5!m2Me#D>8KzJ}vkj>cnE!>~FMY+}HHKjE}l^}LcR=jpFsI~KT7e+d$%7Ufv; zS}3BlzHuoURSM?Z!BD(sLEe@iXJeUw)2jURGiOl5QBJ7l+s55YC!3<=cnh~U zo@^;aO;=cUG%|HC_Toy*>4dWYeCD@d<; z@WRyZ6IYuW8|MBcWyV1_^Sf+WW!yv875NRZ;={Dw2HCRB;=@@=t?RzH%8w{^%kX-s z_!Xt43a-o3w>y}XvMD({rQ{}I{W=|*p_QL=XF}+DQpVkjLlz%AUar^YDSXG@e|a3n zaW_r8i5fSlnNs?cc@#=t`P(jALpLsfd10DwRab-N?^W+n7UB66C8NqQuoEg4NPdbu za}GYcIi%eAoA`FYqtgruy1dkTRefr%QqkRE_n&4Qar!T%9MEVTCcRbC3RJoq?xfH# z@a>zF*1+ITRk&3P)wCqkJmIfL?gkaF=ML89NV+8j z=2&NKp9OtI^9lP!v74{BQBU$!gC>%OLjew#O-{1d`GpDK#y+Ri79gNh3pz6DBTvkvl`4~6rj$4 zmIY*OfbUARZGc5ze}5C=JaXI1puURd>+(7i#EylH7tY2dk5f<{!RoOQ{I%6E#TGnS z$v{O|_h5RP2sgL|j2Tc#L#2KbQT8Va**dNlRDvz)^JDfMF!OE)KRZLfEg*1zTyA<5 zkyq$VpIk%6(_D4iXo#3EEv=d8tR95hh^-h1RYC1^)F=J^M}u(}x}Wh+qBW9a3J{ehFgzVWq`Kul470((7PvWGaf?A4jz}#a z&!`vfaL0|3j*oX=5hC3jVgoE{yby*PKC_`=N|hVL>^n{9L~8o_%E0G_%LHUAR0t&u zI#yI9CI|GOV8cRCJBH0@|7kt5VbS>)d7fM@vsb_mn7S^EtYP7ua5F4+QL+TcBQhg| zS3{JuY6cBtn8?xpQP4X58OIK2Rd(U}Q`gdx<4;)hy281g&n=p9jBIp~W*Ka7uxvR7 z%k)D)))sqDLTv=P9)Cc|{%GugBOOu-#yx~4VNu)XeqoCUN7Dy#o?#k}A9%?6=?tqO zm`sSeUx1tr!Ju@StTt|ZiRL|IMu@zs7e989Y6XUM){-Px2VL9nneZXL5Un(?Rc%pJ zCu8q{d)o+p)*8-u1>E9ui1jfiBNT1IXA0i``Gmr_U8s^0!Jg8-;(>dEBFk9r(sTf*j8q8I={)=bxI2P=+;i3Bg~r#n;wn{;36<1WOoQ1IO@u41$tX4 zd?*pC{!(vz>tW-iYa!&dBjp=>h(U2{T(a^WZ{OX|ln

4K&Q9Ia4(rGT#kK{OG2 z!9ER%z9q^LwxfLIQK{}u9s57SIT7QjwSBbgXVdMN06ilbn#%6Xev|zxK@TY(7$;QR zs$+@kBDZLf`Sif>VLqc(5J_!9Bt}2Y$evf$H49=j57JmW@lWVl#9}HBa!F_=l8&_R z+DRm2?2ZUU^JEbfDZZ}MERy{=Wxj!Wh&pmtGN?3#U@!qq=UwM2>>dfF*`nUwz#_3y ztX*&^HMsXc*GnN9&Cz1a>vLe0cgZN;1n%seZH^znL&Md__b(bZRYt%`0(f3A2Z#aP6IJj?q@gt-wi zk2U#l7&o?f?DP~LXQWj+*FmI{)1!W9zpPXAk?h{IAd_S!Mt#xpz=H?sY#6@`1a8c( zxT3dOyj6U@{vit8F3!B;aG}?yP0EYTnw9FeMK4Vwc8L&{BBm+n!&(((?DyE2ao1Pe zSi>HLyOPr6y_HTf(3Mp1QKR!}(N37qW7OCZAgJ57bGVi!?b4tpml5iK{&Pu1;ba(1 zKWypX_=RAW_!f&6D*o^?t6+5`yVfxW0kau*S0N|K&;*SR*zt{m@6kluA@S->r=yEa zTMOa-hCvXSs$In^+fGF|R;g5z%zh&?P9hQgoBW5iRoX9v2`w)ZK+c&-d+K;4hJcI? zvxG(XmqEe=R`AAJy56667+1;rPwh|{=;V7HM*MOx>VFsfDiNj(QYNhDTI#aoIS-=Z zk?ucNx^^|9ZONzqK6B1<@^F$`J&nla6lrEv8h{rM^Cxs5nr^R$uzsMy$E`JFj;Z11)Am)aC}m6WTl3xH$^;t|8%vb z@RcoD<#)s)nc3rj zy7MzeXms)QorH&dtrimAzn+bDeZ( zd`s>Nu$L0h)+1kG#5~LQ0q#LKmmHx>g!$HAkC{k^EdYMq1&uqn-9VdyiA8kabO`9h z2{_b2OBgD4@EEe9fTxJIjUJYqslrS)x0GP{xuPN#;e5c%k{1$o&BhU)7|8al;U?Gt zlK$Oy8;^m_210#AFA4r9L)~l0CTZUKW$(vX9idDO7T);YVp1Au92m7BLBk*?M9d7x zMX!KA#0LTxV{q+>?)A$+w2TJ`!D)PGZhqK7?j<0Cyv|Pg2EtCErH{`6nF6vR#Q6ZB zOlRV%|JdA7`%EbOxE?KHverNE`Rf%|RJVf%Xc%n7-E^2$c@maexPAs&^7}Z=kgo`v zn-{;n=>Er&4am7g5Cx`zOA!XQ@FYN+giHnj`vzEJ)Dibg(c&a|`eOFCqqxK@4A3Ae zG0HeN9EJ}o1*P|cY_a6=25FPF7UW0%6T9c?Dh3NHGZ5h;=w?WKmpZ=zJ^xlz-?Hnx z)7JDU_13Y=o@M3d*We|GB&&ZAs60*bBJgZMin-fzg$sa|FzSPC)CVX%4*q6tpEM^t zKb0A=ZyYodbR5P(q^U5&zIdf;86uNL3{i_x>SXw>Fu6TJJwuFN(9kVQTslFM4iq&p zV1%VZjRdV5!sM*7nsv(xBlIyJhkg)wcP{Jm^ASD__`F=>#q>u`T?j3~ALe(pWUQCKf5V}LKXD0B{zCL95@VEJT59unejdt6Dk`dpcjM+y zZNJ|GnJ2ma2z<26MvFax)ng6e+INAH2clE1{N>qm@f!Ft!DR+1jz?S9K@6;%x?1xC zoTN@m@vk74odkT72oDd$t(S+&|4E;blx)Jj&I7wv*XbiOfHU4H3?@x_TZL7X^??&@ zrApvIM@m^CuHdO#q3SQ#r;9{jSNqnYJLKNW$=#8nout}I(`BV&Ie8;aU}I)o;61sx zk*{3>`z7wr`98Pod}xb@__H-!6gj8IZzFMo;I%a9 zfB4m#e=n*{m=SSRtG!9+Z~B;?{*#qbjjbK+t$*8uQHHA8!2Mu!?BUEN3v2Gm-=`!f z{(}wSO&{~pwS#HBaV>Pi+a@-v4Dj(RQbmcr;eY5iA;@SUK_G~;;!{J)f>@cy&G{$dHl=W|4nj z{BHbXEtO5t({n9S_G&H`}^nZ6}oC(x;nx0c z7L>}7l2a^prV7gE=)LydBe!-VTvPvM)=5jwsfb2Rpu|>bm>nJtZj zq%wdVTM2f}pa6jB6cU03Q#-JXuGogo3&JKJunS~xa7aO(SxR3S&^M6oMEtP2PhC$B z%KSUX1s40_FWgR1magl1&}<-q^heD{rhJNU9Duk92?Ij!trBC26`Sv*HTYKx0E;nD zE(KoUtnz?+3mMUsLLMJ2L$N3v%kb3%!ZR=;W_oPDLjU4&Yy;=>xn$|}PH^l-#q++v z1BJqWAcwl3yaH9{Hk%4?iQy1~W8jgxI?v4ECt@!!0tb<^L3YKV(L-L{SUDI%K~S_X z2rSaTmj&^Ja0LT`f?c=N2Tovf$xHHJInXX3h&8a(?We%y0bXrKZzy-+`L4g6*3s1D zcE32IO<@5@7~GYmaHB}2bnz(zDMJVptSFUlptrn-?!q8iCp!qtmu^_&fu^;W7IYCquKCQD-5)a{Z+l31OK|($Z%7xHmQWZf@Sa8uwl;O~Vyv@yh zuP7O=_(v*2zTlxpZbb0b-=~XMxhp;;{0ZowsNTCl>BZm^1Tx3^w0m&{wf#?Z#1aQL z_upT7mITgbC66;}YCgdW3r69`WiJV!quso#GG*JBB5~>Qc_1&Z6oe+wd5V6a4_N+- z)c267buhjRJv{Ov!0zfu>n@_7fFj`4$=)iG?StnF!9PaOLn77(x2$HXza~|j4M%fQ zS5GVl_}Gx;9<;EVDsM=HZkXMuTYKu2&i5f<2Zrm$lj2F0q-s|z=s*b-aGv5il zvY)RVLlX1t89~ptk`tGw#rne2)((>ovws#|NXb~$-?Q1fmvB3TIH0+4{m{H7a9qMj zpxFL#r-u7=LBk95tUVV7BI2#%!(isGwqVk4E-zAhGku}<(46h^g8iYZ)~L>nS$EYd zRGtm!l`RbZY`SFvwDIHk!X=+eoBbj4*4l2XbR8EjYck zkd*Mjdj0J9?u+ewZ9>6m`oeho<4;aS+rRuquocICgZzWM*rrQckdK!Q{fIsG2Te9&9Y{- zuwFT{B`o^1V|Am~ALU@{+3bxU@9gZ3BpxzGfyUsOOg&W2;EWMY+KfV}1`DJUfNoCZ z(X@1Q;ERMsFF_`k24BGEVPdh!8Xd~vZGbiY1xRv?8nc~Rvo@4SlLNQ^I!*|bZMvHY zCmlE>d*5;_EGVEV9)jTsECRY-mU&N&aN|SSPAGCb3!@7dvcUQ2A`#zC9Isxbxnx)d)@)EQ|s%@4X8x z0wls2D%e*36aenPk94m*Q`@i%q3Qr}c>BqG z|2m!(FME%O_g$KDgI5`572XGxFLaR^h4~QQ2Tj&RouJj4&6IV_84%68y2sl9q!yB~ z7j}=uar;073+t?0%{LzS@zH~aUM%i9@+@O;(1)*g-MRTeLtp>TLp^)&gTV-q$GFYI z!3s{lf`acwXWLVaz~IO#EDS~X8wfT8rYI-~U-=sQ<0g=9fVr4J=czR&@FD&0#Wf`n z0fqTuIzK-jH21)392_6-0l)I02p7cLS5{W4lB>wcwRpxY9T508cDy>F0sUoK z|BRN0sCw$31A~#}=gg1^8dAk5bw*0dTlA^?yy1HO^A*-C)VNM>Gb;KFDoh53si=~J zcc|34O&{Wz|A7qx@yoms%>;xW*!UsWgXP3Q!Ge&N?!tivJsGmtlQ>^hMGjT}g5}FY z$zuzsO!6-Y$`)m2y{8)B!OOHhq~w4pCj!ufixy60xVgSrQ$i|s6NGrfsXq&YBQXT? z2{4rA_y3w~X=YrFX0$6JUwN)r4&Ef*b|ZCRx+%W(YTC8GV-*^=v57_o<*34ty{+5b4!zNwo)wIT}` zRM2w*-V7|kOMy(uGr)0yd#%x_=aR!u1YXJ1IqZ5@*ck^|`HwN|s4)N|4Pn3D)M4JV=GOayerhSt zeLrXAj@FlY!uFT%aa6xFT4qMW2$2WWIZ!tsV3l^`zaCPY9oEQr5ELVlBlhrT>DIpE z2%X0tHxoi$rkV5hT}LF|6dn)kJf{S%Z9zv2LY^;RZGXz?FgmF7zqJ>tzjXM2AJ2Ap zp^$Ke?8@BCi!WxW%u?6U&$TXyjK=xnm9Vt5uvdn*h*VXlD$D^69U;FEuyXuQ;%Hst z$Imcn=tHqg0>}$|(MZPs_rteDq|p99kY#lKzA|NZZ#X{!I%@AFh# zBr(yZ6PdU?yh?NbV)_?9lHmy%w1e`bCbbx4TYYl(0zP_pmkrumgKuwcA zu!7nL=@%d+g+dMj;2~GcK;SOTj3pWi15_OtsX#7~rTg1Bk^b*@3;unXzqgMPLBl7t z;fVhH3xBFO{jQB?3*p;TmW}4B0Lm^Dl~uPv7X2sDh3={Eeu}-A-B`1^_k7I zFJa!UpJl97ga@}8TAipd>m;@oWF$WkC+A%i`yok@bzeaebyq)lD$BR^_VSLT2~Y(f z2l!p><&+hrKa`^|`Im$-%VONKmtBzzr^;UPEO4g7F#ee#R8)xG4D-shBZY7{lklAl zc;homOLxSaLJuQ%oPMSNo88>}k4H)zQ;b)2B@qi$?S|w0LFw>ikjC+$a87&35`n~V zN(aaY!?c*PLrP2Qvd;T;Mn^N{^ySl@(yE2XhcL{igYs)0*0UYlAu=^TrdbUd(XlBr z4uiq23bh#u%n(iz1nAtklC!h3{h@q9f#HHTF}w|+pBt@E1|M#WBEefR!BFq$9gp=2 zmq~*5iy!-&nE2~dv-M$QFRUU0T3=5a0{QONke!3Oqw7UuyWu(a`+i;~$|OBlT)q}82Se!K7rGaHZ{NIuviX6DOV6nk4phSaKgGkm z#o*JO8IA2qUDQ~=?^fJsOZ4@_VhV2E0o97Z*77T_*7IH)X4ea;C97XUvlBz}?(B)Z0mLU?oNn1Hf}Bo5&~!+xfGL{E>6#^R`mL2JxECCRd_*K6iq^9(TYidDvnBh8Wy<%0VC78L_%1~1(=>th}*Cz?hoN4-Yvip)?97hFlj?+&>$M`6%m z)N6F#RW7Mr(7S=|K0Q926d9SUn%4qNKpP=nQZ1X#tr#%XGneU@CitZkv#GV z3c?D!>Gg;1#RSbyuKAH8O`r|fs*$OF5FRtu+Z# z952RgnS{c0Q=?QoocrQeSen|W(&X%(=3MJXdHxy0D7=sDYMp>wut&a*4EDYLx`>)U zoR3cwJDKj-dSeiO;ZrWm_F6)H1EZS>Y3YLh!_%7wV%4?Z; zY=n*X2se7u{(1hWo*`_cO~#XD3zMb_%pnd=PWSw$EEPRJ#&n9QxLn8r$X=IlG43JUiShmiH1t7O;@ zazWk{8Ek?C71K5yy|iMaLWHT z{3pgoj)dk*F0Hv`wUVSTNSZsY9@>d?4F9@oM3W5(@ik+CBsE@&LVEMPcsvt7|=~n-aA*|A;O+y@E$P)Ug~2Hc%E@q3d= zX|JA!hDO@rPuzgAvbU<$E~z{)wlylnR- zgP~+vhr$~@S>8YTWhtb&*Xio&8q8Hz{sercJLBnDj*J7O!iP85+?$U}Olo{i3We`#m;pL%~$x>I%F)oxD{PtP5<4D;SmR!2tCxGNsN z*X8y250!(XqcxYGd)egHty`DscD`GmIBue^PuTIL@%ZvyBbR-z@Vu3~%}&Z2j6%s( zZgAL@sdM=7yH4wb;xl!FrpMMO`X6`>W?`6!^;%%%`41ljz8FthjQT5GI~?TM(4ziw ztW{D9SFtub9y3;0`Rd)(kfYh(Co2Zr#rN!~>B#-1_b#*ek<|`6J3Bm~9MzE3xaSPv zYL1RYKQ&+X|KVkhqN8y+eY$d%UtdSZc_yKAj|zqcySUt5&CST*vD(qM;RDhR4b{~t ziF{I0@?WPQk`crMc6RRCvSkadki`C3K6H=ma1h~P_Wkq20#rvr-=r3o)=IUJuAW~0 zPDYj!tQm$`=vS*)rh;*1TUeGIhD_=qzO(Ci_vC?CAFP`epK@F{#A) zCG{_<7`m{?I`{s%ETiwHalvL0{t7zRi`#{qsUCW2+eMknj`JNUQINV4FteFq_xjdS z_Tl$$Ua7s1UiU7LGn__YZN_Bf_WA~&Gg7w}8{T*8{IFT~kc!jpF@;?u7o)(H+g6_& zLXLPU9k!Jbk8S>B*EBGckp7R_{7$;-qk3lJ-zn>B4nL7@IQ{K8rIgZAR?KGimiOHZ zm)h4#t!IKZ6uRC_I&y<*{>%x=Zb`PYw^*EZu8aDill_*DfuF(T$*X9u5T21=ye2z# zN;y-?6zSjt8hpxm4Ja3!)TLBrgjy9GFHxu}9W;;G?A-MJRj6aEJDoyxR7k!4SEii} zoW6zA!LMZ8)Yx(&UNk2?w0GaG&5%Sv!`k`QlXZB2Thru$pbpyx8lz_lt@OUHl(fMIyE~seL}+um6psVcqZjfz~ccUZ-T5V;*_lAO7Lg zkh3j6Q|-6<-}E>qCyIq*>W!yTd|hv7|0+1lzu918#z_1z^?9}4HtUbs5vtdO)`i_n zj4Vysa4G4ns$Wnh>e6A;cL2Y7a z$jLwK$c=EPgOd58QW?j2WgOWpyb9QTRDC`H4!4ck&I~tv`TErddks{=z8^PA_H@g$ z6m;uc-Ejh2u|1kxk>z{hVVQL2m+|qQs#Ayb^z@YcCSHd~KjNkREdln|&Q1hrU1%_0 zT*u4HdsX@v`aqH7^{u~B6p6%uq?SpmGw1fr{TN82q^6TsDx`vSCJKD%RQ)9i?6Efv zCJ8X}?kd#H4pb3DNVFFPQ+riH=hu9MfiUyQ$57L=B0;3Yu^}s@8a<`HzPlpmSQ?$> ze_?Rm3EB4ZH#wN;I@;TbgM~x}W32DDAw+$BZ|+ChuMBjqxvZD zaPuX7u>~X3GdMkZ^lKq_=YRhK5qhR&Y|KHtJdYmzt>-xMMxG_9?SQql00ch5O)I#) zn}JhkO`56@OvR29wK7m$H}j(f1xGGo;M$O~#Xv(92p1>~Rahq#!-zM})y@_eU7A2n zxb#(QP8AzjSl$ybCy6WieDkU`Ha$DVYzT4jEsZeIsDny5(>2~|C&P0Np$hQ@_@_X3#TWZlZw%g=78{p8wL`!`GBs zNTMr5+PVEGZ$x(QzIyCr&}E!Elm<`s(yEivZD{4^*+2{=+DNNYAkw*uJWw3WIm&04 zhEIA@=SQ{@QxsY?+>pyt#K}#J$E4X1*`;ccN;{;b$;c37aE_*SMaCHXT^?F{1%MP) zQbJNSPr^OM^5a5K%Bf69djH9)3^Z(7M)Z$~`xYsd`q8YLlf@YGy4r{UxVExz9<~|2 zvQprjzkkma8Rosh$ft?^%7ThJ2KmWdS$>6*MwTcdI;bKb3$Y_Tz3)3U`HY@wR zu0LHj7>L{)0%5>d)7={=>)WZvT6F8yX?!UXNtd843T}Q3a)0yqZ-vZGqvT)6|Fp&J zTCzN0VbOu;&kl0GA=ShrCiV`E&YtsImOb@6osPYde!YMLQgWnEKFtBU=Qw5RcV`t; zIBD|cV4yh%5{P8|$;!$uP3I5K6EYoQ3SO0WZ8H3<#F}IsvT_FRV6TnpQP`f>R#%AH zg051x;WFQe*TSIETalg@-8JqeqXK~(FT9-dYa!rQT^zn(*ohE;w4bL9r}L-0>jpob z9SS0GNH{0IUr58>z6ZRT8?w4ccY#~2IwF4KPDgBEORIIo#W7tf%LznjzRo%+(?Y(B%GHOx6-W7*5ZmR4>Kv7@lp}e9f5{DHm|t3xpyHG zhzuf$SCMdS!o~zpCP13}c~yh=k^kMfDa=`h#~j2AzTj>P3kw=5O0urs-ZtS?bmsvO z3dkfo7mQLQwdsd{oyx`yTEaZczYv<_-D9x3`)SaHl{)lTRiU}}XZ_l0e5!x<;v5`# zGVqe=i*pJF{}o1}Vf(2re6ael}U}Z!PvIeaenDUBlYCyyjiCoQk`}J8E zH7Niv=F(Q45APZ=NI>O&pq>q0mh@!d?n&@4D;{Fg7QKy~u3pBWoh^vP>ykr^@(xRB zZZc-Tqk2$VTbmBc{WIB}%F2GEF5g6~WY>7leDR}48HD!3sOic4B|fH$IyyQw7(C?n z{cS9~*+^8JIFu>`4EfQo$AF}$;9ybsy91XDD;?U`mP&u}_{?t5;qvxkGM+zs=h72~ z+y9dDsr`lx{TNeX7!-azv7zBQvcjk7ZATNSliq50na#Y8PdCT4EijR?_YV6GV|JQk zI~w<(M4mVN>Gk6+^`dhT5*|l{9CI%BcXJn-Q{FtRwY4YNN@T1;q+|B*^C|1c$HUn< zG^xAAI4(q7?$4HHIGXHcYI-N?3tMt3?Pu|RHaktJ@eAE{_lr`PDeulH+D&@4W{7@X z7E5?;_WretCdCJH#y-h)uRe(D^;qIH%-Jm(!`Zf(-3?Y*?eP=iI|ktVI#fjG`ezZ4x`<=a-@o0iYWXzY@-(Mi?G zacq)(mQdsx*;VQ+z|u`?9L~eLE?ccvWPMz|LZZy0H%sRm_Po|_-1&n?P7~ib&{9h# zoJ-;}kF!R?hTE27qTbT{nZkSwacUo#H&u$3*}P6J6;-y?VXOQjDZhi-`wXB%`>hw1oi9kVuC9e5!&IHs0BJ5ZjVd#piHs6~4$YS{P4 z;?t|oQf3V40~LgrUm0<8h!4n$g&7y*o4UquQy$t#ZCu6Dz;tGkI^exPxv}Waqnms< z%VTy>-qcp#nwXwqM{DPKSEwb{bo+wB7Kz{zNg+!uJFhCc?J@oFYCag&>`}Z|Dn>l4 z`S_I5QU==#su9->Uv!!iX8FP%S#68bkNv(rke!2uZFGlb*oneC8aaIEgB^Bx45w!& zt{sZurZPDcTgb~Qy*a_zmZip6Yn@j9U_!Z;-Mum%IWZ~j3n^~BYpNWDdyLfdxGMeR zZQQr#Z{_E8mYX^T>p%VnkBlOana&Yeu40k39NNGG%P35J&C+B=8YiSwQhhNEWm=?++f; zaD1+#0Fa{hw&)TG1=govKVpm0P}wFqIr$km2LF8U!WpWYN>ZCq6R|X2SUaxCMw6PF zN+eZ@!#Qe3T!06pS28v>cB6-hyn)u-(AsT}o!hqAlax{r(Ah6B_hN5xS=m3MZmFwO zKOHPuJ*(x;{TkpKK_v|}SFA&;V zT3RTOjvnPEZ7GLur4a`bOij#m)NrMK^loC=w=&mr-{S*!{ZJPQLd**h+p@*z$dT(8 z9^Nsp)VOo6Z*t4lt+6BtwIfYplAg7+H?HTK7#GtCc1jM`=#H)~Cx8N&deP4TL}Zip zRC?=HTSB2{Da1E6YLNURNvJfC`qP;?K%J$qiK<|6{H0cI9 zbKyvRS?AZ!L@>}UNP(OA;sSeZ(EslR7z$Y{`IKp_x}q(A;(5ruw}yi39$mf8CUfWK z+azi|S2N~KjJOiC3t0U!UOM7bW1x#LFF#uAKc$CZ8y&A+;qciLU>YMC2c&k%t)eDM zO$aE&)y@Pr19wJED75EW^XM$umkHOR>^BiQG(CxOYSW)o`={mwFF44|iPF=Yzy97Op{0{`%6_n{vLa8`-G^_$&5H!gqsS&b&LE=j#it+tW(7oRuhO@&B12ki(Pq`$+)M^&4NbSXEB2Hh->rQ1 zr#^sZ(*8d@ZApS{t`Cf@tX?#&SA7AT)A9Uy zX!OWqa>#iY_IdZ6lTpb_8mv1zd3)5<>=Yt+Xd)W&xXy52I}4ZpMoi6V8}Y^%a&3G` z$9tJ`bAF_m2$ak@SpCUBGsSEci!1oHwjUq)5K+D%g5F*+yIVdnO~TMlo-g}_la>)M zFYyeL?0PkaqzV5zTyD2$6W;1QK<$dziFowE!7AwJvLAjHhU@n6kqxfTd-O<2+5d+Q z-WIS=`@qu$yw6@TK6G9ZU@9!^8MvUv!F-h-J76gk@#f0sYaw0OXB@ zvGivO1r$@b5E1GD(?HYVT>LMwjJ#Pz%YiB!92~&T)PO7f?uu~T>wX>`Wv8LM({~7u zfZSx#es@jz43)_82Yju40*0Jag#bU%0s%2taXhzrbRo4~ym+yH${)MRUot}%@6N@I z*254d1H?ibY)Fhggd+iqD_gg3-@?zDmYUj#EcbPw+5pz{IG6wBNmWUwVaI|}hdli% zfkZZDI$e~5NTDyhb?Z0KbFQ4zE2}HgPdQ?h^T;?h1T#+>?G3K44~t+{3k>|X#UrD| z4m1y1EOZPF=|xhsHtucGj7=(js@-J*-1p*zK5;ytihwN2#m$W^BQ`1N7Vg`=KmGP% z;^MzRL8j*9lmYjkgMo~K!$SvY1(wmf#tO^xG@2uZeQO?#UfP?LDBuIAh)s(OKKL9*M_<79l&mbQe3L8>s7jlOrSk902oQf%lFAZ;^Y<9G!4PY#QT;6W zSTsm{D>M;ed*=GU`j>Q(BolZ{66VG%sKRf=H2~$5o|XTR1Jhmbqk#Ha0LZ#JI**)J zzF(bP3?$B4peDrEb^IKx1L`}-MiVsX>~3EXW%VmvID@8H?8c2-!4ZO?CB<)pN185Z zXt2@p(~3)PW&w5Yy8Cap{1IwC%U@j-neuF=-hUq2xVouT9Q6qiXg(`VBVvAhQ|xB@ zmh^4xd;9flLU;^%UvtUa)O|#kfipY2ZG&qy4Z~$A^W(g+ysixy+u8OuYS>)hF*8!9 zJ#>)zQm{D9CYnE6-v(_KvZ}Dr=Z>xU6t~B)aiB}zCMjWgTgaEx_W@_L=nn3dixQ)J zU|ir>BWQewW0BJQ!=_kX`<8btF*a1a24dSYmzuozVyMQX8T-7q{rM@*+(A`xFrLlK zSZx1mGj9;=!s8yJPMtZb#ywZQy@Put+bywtgr&^QB8vA@&WP4pAFO~0Yy{U1^z@dxD z?o@vcJLXWoH2=k-w4BA5d!@UPFY9E}r@n-H)w~IoVsA#`DSR)-6%;Uz&pW>87?A4BF7~GM~$J*l=hpgh1Zxhh8ycy_`+*kYNZY6OfG4@ zP+J$JncU@O8uCTL&$^Xu8@p<2ob2(&fe#5GUKvMRB45Z@Zn4ugXR{abGrr+pkY`hw zAbhmh$fh;;8sqV>lwk`_1;NFTE7f(@$C9G#Gk7no=L~0fdd>FxJ36B#OX|{Fy{<}N zpC;VsBBIt}^rxJH(?jDLwtjy6D*S=H=0fuBazbBQ#&D*fX9O`CtbF1f^D_;Z)3NRSARv!0c9{Z&zR z^_Y#{B@}*3&ugrF*OfQBCQdEPlHM*yQ~Ekor$@ua)%IDkQ2q_!7_zG4*i+oHQ$4$b zW8NK@R-NKl*;t;(m3b%2_eKJfPUDsy4@O6}D#YLZP?Pz0nwMA0XeQvv(v6RHl{S^; zeC`jwkITJ${#*pdH27b4#Z8F(4mkHjG7e1t6Odgr5$Zr;IQ$C>#lR^8Yc+dTPYxG# zIJhPoi&$Qr2$Sbw{(e2V*R{o2-ug`}W{VkFTVEhx|CZN{Rpa6~?sJ8M9vp;)otp_N)TF$Adc=yH;ooV`e z0|_m5H4a}6`@P6t5v0xMOY#`qWn@_D?8Pca8=j?bLoF8(;O$C)in*5M@2bk@&WfkF+2M)hGb%#7ziL} z6d?2b>o#i{Azba>?etxn2d@AmOZ`t?>vfZ zuugF@C)N#<$ya${Df(r%n<-5)`k$Oh(XU= z@IHZ)xP~6D&;E5558ZiB1}|<2Qws&s{Vxg+;txg_U|986058l?BQf#V`X>hNNhoF; z7@l0P*XfzC`uU-H=lS#JG2Z1j=#Sywi+*Y+jko}JAp?0U;g|*-`Mj` z_8`xox@G@AwX#wHM>3l36`DU zrK`&bwn3Vjj``X20;b0w#rKXHa36yS^A^}p?Djc05%1xYA8gvk!~A*n++1+Xch@s# z9tcnh=^EdQezoKQ1jmAm`ZJH>UQrrxZPH}a+@kClrbc0L1ve?<<(Okf2`X>t&e}E? zoF~%@1uj|EKV7NKjaIOf@dSY%j+V~npMK~ww8As9l}Yd5=5>5`@{EqG;I4NU=sx}T z>yXOIoC)lYkSFY#W7r7?i&lMah!@1x0+>2Exh#c0!vWU!R=!z*oVT-1G*- z%KuzOejEq_e5oAP-LLq(3gmDQ571m=Ahvb*bq~sNF4~72ygve1Q3!bKmcEO zJ%}$118LvXhlG$sA1qk_%^OIkB|R@?WztOR)>(j&846zBezf#5;b&nWgDTrxsGCMc zAcHyQfo5f&mX=2S3`GRyF9k$gcnunfwd~*)Wv_yVD|Vx>vVp9nM^0_?h_^SM4&mOw z&)*$zXBOC(0h`kVdkGxIy*-aVhODiS@d;nPd;yfhKGg62FuuaArKRN`x(^*cH^)+d z)`>?S1hz4-<3VJKBP}g0NmU0Vx34sp+_r5L@nr?r)hG!4h`=HCysm?Q#6ae^upN+( z#Kfh0UPDjmnzdzkPzuVTCItpn3jU`=Dv9Xzg6|6G@|Qp?|JUY;&1SqS&j~j(V9o)x z40vP0auFB>dVF|c8B#Hc3HF`B8~H6tumPep#^i*s?+W zw>Q>$@Yk<7zz2TR;4lwXRv@FApN5V%@Et4!o1(wxdpwK<#SV6eWwwluQr-qejF4 zZ{=%HU}w;BN#=uK1IPdUEvfcpf3w;)6#`BWy!Acf!T2*y$22V^xu%dkV8 zybnq82KzovcCnb5fsw2otTWHXrd(4~L&!87N`Iz@x^Q`X$(O^0l5=uiz)-hIb5J4p z6)O90*z~Yd0+&+FV9FJVvzl}%{~n(&dEaX&7ng?CW@t8SY?M=YdIrGUB##cce6}cE z(V_bOp+_fp%Z`^XzqW4ja=?sie%ObMxW0XQvbXwEz&5OBtlUcY5j7kIb0wf3o|sOYuB z61NUt`;72U>O9mp5ksJ^3nZEL2h7=^dk3%547l-1v{}E(-G3+`KpCB|*YUjxyeE|g zs@3X-^Rih{=U!h|$9g7%;y{9K(u~#$xh2Vj*Yh)akgtQ0jL9!fJ9za&*TL8X63Y(? z8(aQB?U`*m>;LUi$y*;?aNg^@#TA#Kr~6SdY!Ma?Ls2t>&ew{8)l1*~h1e(s?%lgL zZ~7d+)79oU3bD1VXB}7f?cOF#mAm{Ws{7%^6lRezW=1!9YQ03g!A2q11a95YS-Jc= zc?Bxw70SXThgW1Gw{*soU!^>>jZv2UODe-Vhe_9G8C!*qNxm&x z7`5J@Dd$XcG&&`Ls&e;nnNzyv^gJ_x9GZSH=I?oA#Ns5s@|kr8{`HT1RyK0yQm9Gn zR)0{QZfB{_l)89J#Bsm5dtcUe%0W|R-P<5`vAHMwx_$GfwP+9IYyOC2SiC{|tXed| z_i&T!`?vcZ`Kz~`Xexh1m)RiA)-NdTZQAf@CsozS?cJyc;>_Rksd3G0W;R-k+StN5 zTf1;)ov4MFou`cL3$~OhEy;BD6S!*X@DpkkM{b{cd8k;4Sv)LM@QX-dx-@;168mX4 z#_6KGqfXn)Y+uRQJWqYCF}y0JHm@O=r9hQKr*c*HCeK}e{J|AAd)hB=+G!mhQi&Sp zgj!V?&sVG;ddOxjeE*Pk)jB1*%*wG%GAHbxt$7rf`bfnbrM#dj*QyDW%L_joXTIk4 z?wFdEIJXR5W~v*UyK;9nD$!-gIG$*BQ!cjs=JG6ADxX&VyCVJV)p$V;`DEeyh7CLt zc9XWXuf^FP=({NWJ|!c5sK1sc)CHCrKkt?28ZY}w?xo4)w`cDCCm*)*az=zpWrui+T8VpQh{ zet7UH_iLF`n_>k$wBpzqf+b8=)~i+-dZb3_or!R}OWTmou6)CI@ge;hH!r`XSj*e9 zQo3jREAPwMczHE?Fy`w?t$H>7Z5k0TOy{OBS;shj#_@=-x23qZj?3!d%udBfowWxE zg(d5PZW+(+_WJUPtAI{`A$L-Z3qx)S`;-VSM4^yf%Ix*` zG&eqm)X=TE8uztxxAaM>2JSq#nvuFC;WkIOf!aFf$oR`Tg^96(F^nwCJJ=7nubdZR zQ{Gt+Apc^3F2O}4-%M@Yfnf)p8mpBXk5yBTK2Mg+KQAv&$;>En<&XBh<9r-K;cUXM zw9<4QX{gV;Wb6`S_7Ty!o?c4R^*rT%%=}o2thQ#{BeSGtmo!hscLT!((OOT&g+?5x=j$!W{L^w;$f#vy0dd*NCWLf~geCn0!5(wGB~TEW1;9Nl!^wKopuKGj`7U=Um2ziQnDp&~z_fWIjm zNo?{u=Y<%;e=2NXjRwj)#$3(hz#H}_Jn>rLQbfC_J54h~MvY>K2RX%h#gBRSG!m*;n6h-&pglch8Fk(WpBoi+=?L zr@AD99A?n5eNVEXA}c}BRmr=fyOO^cRd;5Rh&>SV0B*6_mQu?ls^53c2E0pJJSw^&)T{o`g-3TdrM0`5{8HG zN!p9CtBMQcg4F{$<`5K2>YF?Ty-3>nF;8?;Slj3b&I>><_@fn;@nOLqsK}lU#efF1l&y zqjPsYE$$$05f(u!m~y}Cg-_uv0uDs7J`P~(g|Zzxc8mcxC>t5odtt_7xR`E0To?)^ zY8eFV0%yDQmTZ7n=zw82%HOVDy&7Fo0-^tV?UmG24V>2VT{n&tcUY~Xi-02zhxH+U ze??dTp+@85lzhK1`1tr7<9fnnAOZdKX4e$p?MKgQR^M~`1mOKzTwGMvIcaR%h=cPW zg0muW-j|f%LnE3-u7fwMs@k22Lu}9V}S-jUIGr z@)%$M)d>vTck|{=<#T@63k1;kB|a*q%Ym|F=_3S;xyimHXz3^{GgL!VNE>CoB(ba5 z?#leHEPi{NlT0=xfdF7zVZ2Gs%m+lC2a3tnn8{}eLyA&;pEcA_VU?yL^b_p3Zr z80gd^e59nMzhLKqE)kSxg0vTuX;?MbV~7%N9q?^S(8@=kFo@BU-=? zI*gc@*b#H{c8sU{1{NIX*!Adm{^)MZsmFE*MuFHk0uhc%fwpZ*$lsVd!Yl90cB0mA z1g0aMR82f#Y^`WG$EsxvN5@DP))T`SWKUTsiS0Ap~#f z>l9n&G`gT_1+)rA59cT~;C;ROVaniLqqD;iyaPb@Td>&SHQm1b>*SsKR>yk^x&yoh z(HU7;dbl|3A*ihN4jr2J3o&}LJ0Qsl)HE1yY{{Wygd((!jcO8iqcNhE)LfW@Jh$II zs4WI2os^W6FiaKcyN_)Xb9gzal(K#~gCQq$1HxHx&)ZO457FFzXfEwOP<0VW1&VG> z0e^mUpaZjbVGZ$idN_x$*hv|*@xcYbz!^&H$M|(<`c(b4a2&6;U?`+ZQrh754NWKu z>v9lH01N}kfN~$Kz1-h#PIdt220yhxhEbvImJ+qxV{%JkP_&(jq3BFOV z6YY7Z*pso(SXl6qEd~{AZ6Y{AvAuhX)kH@&AR`j61z4XD}Or~d)U!2g7=7Q8+Z!W)xY0q zC5|f0;kh-v?NiXgiR&)n7o8&=u$6=hy@vD*9<#wFIW?67VpTDxKvofi1YOLuFgu%7 zqwF)L&6ZVX_1A8r$QqYy#P|5n(5FGdrbd{^$nc~re3+5fZKr7d=KbZVT%vZGcF6=(iO8Xs-qcO6mDEJqRJuu_ZolAJ`1kb`fMM!792K{ zZ9Ww9s_3jNgF0oO3}>r+n0HFt1vj>rg~HAogO}x|PqHzyACV2Tyy^cWuXcSz|8k~J zLT$79!G!dbZ3=-Gs7&j}4!Ul!om@0j`z)>DkaJk-Ts`ZJ=X~LE90u1J+4?t#jA_f| zvp26>{K-zkMoUBa)-t)?gw0q);A0khWAvAp@u%k~_a#u>SADosfzAE8sayLgsi!Aw zV``G$8Bjkk7CpBwp}dZ=!MDMFZ?UbbS97X81v}Ny^S>(VMI9KOpI;2kD~Y*UEv(i2 zQ9G_Mv*eg+xwhyZ$GtwBA}){LG9^AgW$k)}o{dlIOsqKE4o%Aj$?w;TVi?fEmevj- zwhE(}Kwb?K+K+j3!YkqQHB5Tf;;i_DuSYtI=>{2z)*EEFH=TdNn~+jjdaf^iVUkV5 zFR6gi^KP}=!8eT?T;6)li58yZkytM&z|wmM#g=Tftcad4mzI{FW=!GCo-G>uH^XeW zo-zHhxa=JpUXvJUQKn&JA09VyCbE&T%WOnJ#*Wc zbv~~Pchn!*~vw!dyi@6@7FSZ_!FJollN!84u zmhP|Pkrm&>=5d`)DA2lL!NBy_ur#GpiF1}bvgfh6&u*imj?8BUA_&?yKtOTB(y zXw!%d4h2H}z|>~7<7)2o-}s$RZQxBY_$6`h$xm3>tTDO!B*jFoL>`(!IHxmOjV&#o zgE=F{Pkb^GX&4`!Yv1wjuKEv15>?Vgkrr`8U;GCrRi727VLmoG}?_6D{Y@ zK*7wXQ-4ZW7?q40lZqzp&xP8QWc*dAoq*=q!Q5=n@I}b!Lcdy?BasKmX zI~+nSgdVK+%Cm-z2lx}SW=?_eCg3t7BZJOk3CBB8T<9Xgpu7?r2XPPl@ZZZCeLz(b zGK zvH@AY>N`OoCbmW#M!|Xs4AfLUsI++h_5`=$&G=$dYIxH;nORhh-a9FP$yDm z!KC1K-4#%atWjcf6)4J=tU`Lk)0IERQ4p7>0r#2Va&cOQG0d>h`qMM>JcD9Sule}> z=;zOHLlCcZz$+IcE?6bI)pr_kG15iUR{MA5?nk?6o8vu@#P}0i@b~0qG&VNETi9(i z1)HoRjGE-R;&VwdVbEqsU2{JNv&m)Qv`x;)i0N=ZH%1O9z)HA92Ss-W7{319rD}+V z=0*6BeU$i_Hz_7GH^U1+Oc7`_%mx^o+<07DcV=7cE=1BV#L&(yn~?cXcq?j3rQe5iaHH2;SlmbIzL?;+cfxl)h-nV3QQy=9$4MO zhg)rAj#yj&7a)X+8hqq>Ae@8*=E2`Scp=J;S05HsMP~9tk4>8t-LC^f2Jj`r>Pf0? zv^Sx|HZh^TPW=bB6+JP~0j%337rm-K5QjD>{uY$}6dGuYCvl9#oZvo?3JCdqh3CP+ z-rimqUnj5f?Y(vwFAD8zp(PzQ8o);x9%hWh1H=qtVB=zq1}O}LSi$Jc!*mfGuKk-r zT`c;8^e+k)w#rvO)qG;!1{DYEr5$9#X6BEcyN7}nHlrEd!?#Oai0wKmei9dm7?z9a z>0GEhJ*%^UMTi7NX5pdbWiF=BbFelgJzdkth@FgV!4n2q12F^&W7Fs?9-R(K5<-AK zasM&EH&{w^3j*-Swiz(*iUQ$)wsw@~WWg&YHVe)In{4S$ctyUe86piBYy+&QjCF&r zU>Qz;mlwyI`k78I0JamVPFsdKQEX()=t+i7HBs0E z*SMeYI9^C66vqL4is)QG<&T6kY;dG101q?2oSy&iEB+yL8a&a%GCX7h#&C^9Z2d_n zt|w~w%SQo@&v-2X_`^VmLj+rM7N*|ZMGN>~hp_TbVpxN=0CRoCs7RU* zHvBgz6s&e^#9F91Jzj<9DX#1(!;_)l#*GtIol7ey@h5HpjP>spB*KOM6O<4TmkD+0 z`|lvIQFsXD>e-}V#U=F@8g%<#zIs&-u|H7i4%pfXBUjJ3SbWE`KLb(&=LUo?%OK?{2`t>Iq&|ho2FP~6tQYw40)i%Cc40vzSI98 z#klXy*p20T4r-_!sCf7KUrGRjcyg+UK<5p)U0J(R8#U8ua?$`4t@`~KZSKZ1RK%p5 z_F^2~%cfD@^O|NpT9NkI+XEW+%dZwi0{>s|Dg@Idn zY1TR#Ok!*S?8!n}>8yH2YHDzrOUsEATsqm7>CH~UDTj|~Q}V%V8tcBd z`!jakWaUaq5D4#gVL7ql5iFHzAmS&j*K%6L&ur~R(AbbD#~)Uv0A~5TD1kH|DMre!r(@*W$evWP^Rx?ZB8VBB65R#I>qsCDr|CoBo8u+nK4`%V8`x+-!8yDj%kGi`I54PMy{JYIfk?1J;}% z=8octtt%l*CY?Gjc>l9|tMeQ>zNwbG`udia{B#EdRi@2#*88LF$>ZK+Oby)G22 zICQ5sHv6ZLPszGNRgsqEoGKf%{B_LlZI_(-q2qI8#g0P3xy+}wq#(*nzQrb7a+af_ zc!Sf)ZGSFt-@ksi?C=49d0jONwclbo;}>|=b2n1E`>PpT#%qWxW3fpp{A7SXqI2H1LMtW2@lj- zvmFjp={gHfL%nsR}LZzsuel7(m{>F$@>+&+@nEH;xWcH3}{6h zt@(P7pv9ByF<4SEv*S<$Hk~6G%z;bX!s}l!R1rin0w ztpFfg64;L(hQb5UtbRZGQ_$f~>O|P5DlN+$PMw-Xci|x`6R>TS=FA2?mO4WogS8uy z+IsBDaauI2lV&4WLv~nk>5qupoh>4gM$enGcPH^_G=o z!(~L+L%?<7Mh0X%J^YN)Hem39+p(XIe;=k};nM`HMmGn-5Yae5oLKPEfkKXB>POH! zx+4gAj{^3uASt!x1HPf$!-Ejxj<3YzZPc73u!U#eIq&R%f8_^^xS)r%pm$TaLKI$n zlpt_jfH8)FD=jlq7xfFAbEalyvluGKz#B&_69cENrd+v#h=KCtU!?O5i6_9M1{JN_cv zLWE5KWRjrRaxd-!|E3h3!WK)9C5B878BTqbJVTvy(C2b?%?xO>TZ>u9>|I0s3r#e} zKkHX8{x2N!vKR&c=RZxSCz$Q+`gR7J|CLtM#=(%nE+}XSUlza``YL>p(G(Z#EC7-K z%>btd&-4MiYD#i(1apAPCvWDJD^Ww!PH?<{sA1%OeWve38;IAt=F&tD!HR$-o_6$B zRaGj22#->1#!Ce)savq9dl4O#gua&K2*yPcHw-#61Qm*DN#4GV>ERHlTd>L(Ox7kJ zlP^uYQ~O;C_vzxnHsUk_u@4Z0-!#%=C*FkDKzd9ERB{O1YrtVwlDYNu^h6*6zz`KL zYxCSp@nqY}^RD9eUi4-4y}h#j(uXPTRG%>(g|_=Omok?% z9c~s;h0!65rMw0EhPToSJfDMj0*s7|&_>uGOhKwe1w9TX9J73(kuq?Ty`e?G4kb8B zZcUw+86K;Z&hGAFqM`?WY=TwI12tW@GZ#kZVYC>T>*Aw)EwagWmz>;1uHkh%9}6vY zp-w$^Od@w0#z-O1?s(Xze|^!NU3AP31Uv*ng(QMJ&E3RSghvUa>peGNBs&ccDXOrW zPoL&uyBpc!uxs7Y!gLoD9ZKp6nVS=!$_$~=QGoX%6Zt@xTPI{=+m3zGjR&hGSdKBA zpmmN`SZ!s#9O7!23?~x!lNUtt#4A)%qF1&jMJE!PcBd{Eu(`oA};3p>CGw^gp^}2q= zD+ikIVpJ3zRvcjq#($Gn7>z_|yMAD~h&SN{6l5q{5Z+N$_E&oeV%@=COkzG{Fb$~e zXemV(qECp^t{qwrN9+v5lK3Z=Elg$Mw8iJZlSf*6$;yUR68qFhQBIHHU_GpLp_J6v zCSfjywzp>X8G0M52M&(ycV>N^ zxuzK6`OmOjGwly>`U81SAmR!e2r9DDupH*K>D5#bBjCAZwn z?$d3qF%p?wzvJ=vfYF<*3`Rb4{;0q>zSps4E}EU$FCt&g=JP#^b&460$dC4;$`+wB zp=7qvy5gPZaiAmF$W5MaceGKdJV#Nj_xk(n?>F82*)+Ufcbdn^`Td%+d`E9_BHz=O z!lOZ~$-xy$cedJY>gQEY;vM>6Nmu+*=D|wH)@_4!;_+u1+e}3}2K0Yhm-hLvJQ(3~ z);cEaE2aG-Cdxr>N%iiwxYAoJ$Gk2Ueml&6Lb^Cn`HB+bS;4(k2gUDibG@q=9KS}d zeH=)1!X|+`msWXVgK}=H|BpeIrXMqvJeG--*L8R3&RNVoiW7L3vTaN1+E_!-C6+^J z^4E4m{74Pf$=XdBp33a{BuZ>IZ&#blvj<0w4O0&^){BM8gXySB*Q zr&FnlnJ)gp6WTtDqwbqe+1--=x0iW3hMOXWCLrPx-|I>X2S43|@wyYiQwh^ol9}J7 z(HP~$S(l2^Zh0)1BbIR1rZcVOg0v@N2c2>Iqd(^J>!rD#@dyUb=s&c7`QZAJpGLG@ z>!a1g(o_MeHZ|JmBXvIY@^?2%^{*Se9$|Tt^WP4=LD#10W{Q&g;@lkOqg-Fk@ovu- z-P-ywT>%i7w!_vcTHsxmLpRTEN?qDxA)6lZ)0V5ZEm-tVOb=O3?WB~VVb%1W^R?Bz z_$AZR>Pdd&`WA-ic=q=kS`WVKZ2e~wT(K!A`&<}huqNy6w?S-TenL^H(wrf=)Lk`- zAA{U!=rw1{*NN^CqpB;=yj0p=rTO{m;Rv&o?VCLQitBlwjWmC7sYR5%+*p{svCd56 zp#3>HkA7v#yR8;SvQx)>b0YU2Q}@@?Oe&#KNfbF1AwYNVD>QL;RdzrTe{7M z^7Uub-r5}d)~e*qC25|Dipx}Ei&3)L{I~uO+syaMhsouNkLmSX5M%tJ#JkqV{;$Or z|IF>m#YmsS>F$hOM>ai}meQ2Uma56xmm#)WQ_qy|-No2k<4UpT%`BCpJlBVR+!z*c z5Utwntkbr##*eoXCqi6&{CPM?B+&*BHidNX5q zL!Cvk)eT7|OiK_Tdm?1|f8Sv52$-Rgz4L%T3=kTZmD&Y@p&}>Nv`X4{sjmmScUM2X3xUVoS zdn26o_^&E(^siKw-N}ItOA7c~IykDw5yIu1MZ~TQjNbPDFH%fDqcGe@MuuhYOIW3m zRMzk@RIqTv^;2)Q0CS6sKj|77rRZ1=O!?#V)_l;5L+23b`|au)C!Q0W2sl!MgH>=p(OQbO)R%CH zYki%9D?zFxHv7=!e{azIOzJb#LZ~N+Jq^HQ`xx=g5hpQ#%ND-GU5JMwN#_8cuo5Bt zNh6dB%q-QT?OV5wOoy!PEh*UrrQm1HxB&QQUL0=A3)4?EEM4WH!af7 z)cAk7E?>|+I)biNBI5zFr+~QIg;j>eiZ1B?EtG!f;Y3a1gNlJTY6!_r^gpeQCz?T& z@}fzH7@@J0HU#U0Dn(|!;lYzu1#v32XU~7Z6mWA3hs?t+UWI@gzY`Mp@j-#}G%~54 zcsd)S>rv@}mHf|r*nx2m#Pg1KaelJzH--|C8BVB3K~sP}iF$9>sptPi#KRN*2+UUv z@;|il`}JwY(FbW|CC#IVXZ_!ZM05e*E*QfdgaW$Zx_SiGLhH`dAG=C5vZ;kwqlu5^ zQ^4eQAQ(P-Qn9 z-UR%0sjltAqr;i+jh=vA*kZ`s6~OctFm6S5jdQNNhJnne=BKw(?dC^Oyet@R>X8hMl_>&mvq>`L6%rj0{&ZeeA~8d zl%m4^z@m4(M;f_1cJ_!qFMI2`8MuzbkdRb4IDtU+5+a1=!ZZCJTjLt>y-3XU#UFNJ z95;ZGp$*NF;6YRnkNsabm|zz}l0Ph5@ai?A6`5vRYx-&Aw+#c6+a)LniaKrxhJ1Ye zx?20&%mfNf>DhoMN%^N9{GS&fL12Ft1e7?32No}{!L4I&UqDfg8vl9fmP1ISf;aIr zwhbbys;vDxjSxhUeTGQc1A7CLt${(}xu+NQ2@q|2+nvXaI8;crO$MK0ALceLt*sdK zuOr(W{HAtb^MOP+L#e>~kKPkdDsK?;M!Gg|8)wh$12OY->H9YFo?^;Z(V*%EIKogo zfAl=?ArLP@_{|x_?2x*a(i$!{4UO=A(F`rMZfDc>t6A6Q&{jcH;DE|ZW3>eSR176y z7&vrq{bp8rj7T7zJ0$m}dMG#>JV#1%Ac~PkkOsc>mTpALc!rH8L^-W<+ey$(TpW|I zM;)}ZfQ6si|Iqxe%ML6T7EK%;Hym&F(BPn|rCz)C5?3CMPXDjIv9j|f0Oc?_eZfMI zj1%EW6|ehDoR<*FP_`fja-FS^*YjJ8fhxFVLajglmnB8pzO8c?USW+tuJB}m&|-q{ zf~|_Q_@NKy7An*i*}IeAA|OXB7@~HZ8O^IxINC(9C*ioC#h9Czm0$SSOj&z6h!o`2 z{)ahnqOe;nEd`8qws9p=0qaT$lUMvmcMA1+`3~2_o{~@W!kq&yH|#H`Z{qJ z(1?#_8xk9Y{fd@B9{cA=3#Xu97vA+t`I2xVvkI!bz~Bm+ZEvOjWMkt7wKR@XfnXyX zlDx*3%#KIelQPY%kG5_n(? z4;*MpGJIczGZhwMjKIAbGp+|6kUXXSLu4rv-zIw6VuDUdb0dqCOmZZ!2gQ5yqfCGW zED6%J_ZWu-wly#Z=qnq8VVtw9fe5n*MSu@>Kj8w#Fu$--^lSa*wcQkKG8`S{OlACXH*pviIAZ7v9|Tfs!w=OJg;k-e!K_;w#mKE>E)Lp$K1 z^L)6ic+876{oD1tx4PC7t%E~R4k^xm#lqL(V{EO~u_?Uh8C-FPb~1GJ|D)+Uz`5N2 zuw`Ur&yuVkN!gSnJ0nzzLc>V16DpOFy+@Q#%1+8kkr5FhyMz*jLPwHNLcI6)|Gw}4 zy3YSP*EuKS_dL({xj*;k9;QLnc9R?)7aospu~RM@$!N5 zoiw}|7XMOrN{MoR9ec>1SNWIw;up5amaA<%-pTfT^T}|>$o&g0{-XMJ2Ca43ugfDh z_GbC{d&HMqS85CW64L#Hzc!X5xKPY!J=L4}yOSI|`p32#-l{ymP?n|dy&`bB#hp?A z9P?xSos_4d9jekNv&Q{2RAnAnD)ewkU1$rQG%S=0%aig=3#6?ZuVR>bw%)?Jh2Xg;rx}QmDX0bgTvuo{ev`CWqoV(a8i_DH$&iZrNO8&g^UvB7bRSw+4N<$SH zs$^eb{90(pCv33PM|eIq%If~u<_zJh;*|R_Ta>Of^7<9gm`2?GBJ5v7WB*cL^5;zv ztv|QWRZxfBNY)p1zd>_Mh^|$NV|QWJyD2@P^3^4Vfj!%OmIf2{a>m}%W8~Ernca25 zs4QVL)BQ5LQ5Id7jX<9C(!M%mCR(n_XjhAJtZFB z9#aK&WjB8u^`tq$oFDt@V&moQX2+5cc!Sa+yeB$y}L z<{BSswx8pAkLfM;`tjd%;|dxNxX`{d?UiD=T1$6!tHA0zQ^)BkajhJR-aAf>o;3ae z%6vRa^5I{`e4bG;_wT*5so8C^>6Y%J!immlEDAQA}~; zVL5ngNg)INm>gbvnYneSkFhOcVbGcBD}{#`_F}OOQG|BHf@jx9k6^jFWr(1o`9_7vmLI>OQ1O4C&b9LoA((1rk|ZADaORC8cKLREXTbL}6|4?k;Rc!+K2J9{z zJQfn>pL#WRa{1?}aq#QKohm*?#gDfdrXlDs_)N8)Lx+NCQ>19n`RU#I2)Wb)p!6u< zTJft9qdKHZp^EK4K%+q_@|j5dWo6?LHsS@STkg{VgzbZQk<_;0wYV3Azi6Pn#1Yx& z^M|$QH7@_6km(>Y2gn8UEfoADB?=)=Kva@aQk<gv7=*hLUSg?21oRlYkUZ1@3<^=0ja&;}r79QxyCJ9IHar0VGX`2v zry+NrymKazKqoGLy$P{5&%<|81>d+hIdv23@QNmsCK2pY>r$WY#D~!XbX|36hU88{ z0uPH(l=2J)!I41vE$lNSzYvQKq5uQvpbos2vNjH3Cb(%NeFc!t9xU##X&Mz;TMW-% zRX$@25Utq2F(G=t%;LUuk-Na+WT-JX9 z#N=xIi{h>B`v-->*n~QvU+P3j8~TKKqOu}#9=A#eH$g-v>f1qmSB_Lxd;7lzBawSVVI1LR|q z5u_(s&?2yDljOE-B=`Q~#}gu7H}dgq9l!%!<*zH|=s4h0~zh8m|9x%aB3in4QZLU6hhzQWWZ z;nppE2L}mc2*WEckA)6NP(k@9-WMcG6OSXzsh2UYZriaV2c)7jQzi)w#CwOVL*B$Y zOjgPGHjp=MOjC!8f>s;kV#IZK7#mMNzrDcb_YtU!q^+nKY+@aNxJh^+ep3{_wP9SP zIz2S?sZIkMyRBf`5ZQtCv2JMt6au@EcNTh&ho9;NriR@h4hGPPj_bZ`N=4z{@CM=e z&GAC0g+BT7l7)zB3|L!YXY+4@#L+rlzogV}z z1QX{$0(ax%g_0Y*h!6%%!Znn2sxl+D5R!1q5}QD(rB9%R@XTAa)TX-N<;zHftBDkO z5NaapmJo#!rM=$q@KE(|1eI=rV$Ua zTH}haS3-yKC%}I z9u|?*2VGGPj47=cxz-UY7lxi3$d*aGtlc@G4eLo%0UjA#Go+)@b-VsJ+flo79E5xLpf+8;|KobIF?c)29^E_x+r zXhgZ2(=e-JBAZo7b7cqL?CI|Mf4o-T{Z@4zX*HqlnF)3ei=D5M_iTX5=HDA$T z>dfczQhdvi`+9jd(ClT3(vkSOe^*=)?|jCAf3Iv;1Vo$ts3QWitF;^U$=op0`A^C! zH*=rzj*A<`4Vyolk{A)?bAEAf{>IYanNTZh7tL2aFZJH+*j4VDI^dLE#TFFy$LQ_c z@fc>ok7KSrU1g_bpJXUc_MT(t`RN>SHZyhc8cD^VX|8!cJU!^{s#lSP^Mr`g)Gu~$Y zesP*ZW+z~D^zmR*NMKp}&Hs*VY1X)O?0d0rt40w|F8zzEMd=}t2X2_9cKZk4!#J>mySoL zFKyjddX4w(gQc@#_3*TgxmvtE!(B}05*G6`JtVt~?%28QH@lQMrH}N`eCB_U;gYd$ z!{41e z2BAnTyCYYpIY-9i8|cmNIPvO8be*PJtmD72DjSk2`2k!gFfDte44l!RP)yDL2NP*K zV#6&*`tEezzHx)HQ=#ys^HZj6iQ|S^Z@Cvq37c7m+Jg)<09 zf$$7s7DkrIadm27Z_ZMS7{n8JAty6BxbCvbu9>d^+ez>MUQ@aPSi~a<9;zQSoiqjZ zeDj6{I5aDRmU$Hw*=Ix<}M!SWl7J3aYVIct}De;NSKCGw1z?w@?R)&Y=3X10u=c2@yJq_f3#|j^+1<8`i=?vhAkDb)9fmQ-u{!9=eq+HTDzbJ+cUp>KqT1_5qdyJ}> z)gI+?By;5GJT`un;{@#6{S0KliBeE)MLY9 zll`v0e-7IGSp=pZGBphZ{6O0`)+onM%wV>rqHP-iegsjh;}Z)bI%&McFV#h*i3c&%+nvkvMX!XuL-7^gDS!6fLuaRmzy z-k1)kGcj;Hg9;7ap)thU!m)x{Q_>KHDTaNE4kUos=Yl%|CM3X&c#~Ez;d94)|1=m- zRdfW@0+vc-mkrnc7jRW0h`i@`M9F9>n>iS;#04F4F8(5q&(?dsi8#PZQED#|P1CI2 z%SP_pSvy`o+vvVBt0k^9P*Au1pHa6G1t1r#)8nbwvytC6S?PrJF*Nc0%u#rO? zRs}rSo*-J7>co=LojT8RH9E-zTmc787_chuZ^Xexf+%pejHab+ZHwQ$TN??)k9T4EVs5RKE-K1(AAzafkXLv#98R+L6Jk z3XCcc!Rk^vKB2K|Adkw3<4QX zwy>QWsxmKGqO|_ZCi%t#bL2Xav|+591g_zNd7{k#-lz^zzyVT@k(L?|+>)vqrHhTz z7o+`L@NMwOxh)M#M!^-i`{i*ucDkb)JmVipbOlU0^<%Tdv%@QUYDcG^gJ~;JK1OXU z^20DDY^ffKLs^7~$xs#5*Z%gC=H>wKBag~D8^&fa36LTVBEcrl85R!Xp#HytesFkk zZQY6UO*|hF8itf@oL`8j--p|p5D!F>8Ii#?HiyRVP2Bpbe)^0#}JW{@ytDtPtS_D=Vsl>544TWz)X z*d=5Mw7Q${QAPZi2>*0F;Y06RC!RzRjga4S>jT#c6TA~HFj^)a|3WwDTS%KYy#9zT z&9@u6zeIO+)*kK7oNbf4(6v56frC>c*(Wk&_Q1z~g^nV?at5Eygi);-D@AZPuo05G)$jpHqFSO_x`V+)@emD=th*yJx388JEo1C z_jg}U68FBRyHM@F>C^^Mal@7^j7mCoM)BpQ?+aHn<^KqhFr{Px(0THVGLGrtAm__Co$~ekAy&rssszvp&&dwUE6}LyyiM9q?{z^n~ z9~?a0M~xlH2#r+vZ{F{fypR4eedss+PRDCsY-|(I$e&a$)UEa2jf$z9mG|)3t5!je zI9no{>!JqVJf7~bme5)Bh#7F|uE@)N7s>&sY>vR&Mjl=??7&rbI{S#*S^}X6UUx;`WSsljf zso9(J3X__lw!kb5MT4BG3CIa(O0@3>APkWB$RPyEEc}OoyTaz;%J5?D>0->tAbMuB z@18-!4~Y!ftnQlv{0v&{fBh>^vf^#5y?$o&W5NH|0yOmg=)x%v)fXRq9ij4TApa!d zQL2*;q4V%l(;0mOjY%wXL^%ov!+EqDp6}VmN0K0As=q5}{avfvJ?Di$bCNj=fpZ2H z@kF$b6$&#Kg1+)$Wcy^|@22Hys=}_rF!TZiY0DGw@6{ zqztW8r?ShjdE_NO-B?iY25_qM;y?K8M#!s#e&LfqUTIF(M0Qn`6-A(3U9Mh+%f``< zp#w?&NTs6kAK%ndd9DQ4+ZQCHk+$1+f?Q( z-+mP4W}oVljkVa(*N{>Tg)Nlr^~u!1j1K$PTkhapXP-919I6euPWj?yx$)yuy4rIy zg%2N!aWg&r>G+_ym~ZMv>a)_pY1oMWhruAE?}2t%{P%urW`f9or}oxn#Uw$1g*c#0qWlh(ruXOu}%H>M0DP`#i}sVSbS>1RJ~RW!T^*@OZT zxiJ6H1~=^Z;i&F^#xk$8)bv^g)PcOi(%5~Rh3+cL>XXumqK6On&}1wu|08k6kfJjV z&4bmf+{f8XJ46bGfSeZbNh|`hHl~q8OPFykG2GjF$*R|Zv1G7vA-Xv@k398+grHf1mW6?8&+^c<+oC|_GgXiWf=`|vJfDJnHfq;418)oU(gV%2SOI`w1k9A z5;%rc9Vwo}cwGZ9`{-mf?=?Wocr1YwP*+s+Y?O&C|1bveBNB(o)0k~@lE@yQXMr4Z=3#fpzlqCUu0E!w}Gjg~B2py=?`b zOyuZe$^!rT?mF`^ZH5bdO`BplBXzA7rBkjZhJ}UkZV|Up6Ec>~9bf(mS*bAlXHa@X z00SuqkQsk*-Yvs9(M@HC5Sf;Wl}#MjUGfbOeDDfM0QDh^T#a+@UF-6kpnD@ZP>Pr| zMY0{r-iO?Lx~0ICU2x#q`zKkf=#u#uCD@^43h1yP7VBZE*^ z5@>;K8?SLhIclN_27{gw`3*#$xPp=f-bdMeB5exTMPfm+_5#EJtoL}1_#q!9+D$n8 z$&v%Ugv`*c>}^9nf4|-$977VuaL^5*x@C-8kQMTr52D^*Tu5;e#CO+crh?*c0~Y);`8piNT5lZ{t_Oyw6&Op@d@ zAWm(fS%t@%3}c`-1Pl{NiZAkppAFEbb~(;jdwo~FcYp2pYycaTsg_i5PW6dr(T@x+ zMm8^qh`oQm<$BIrak2LW%FYRP^@&u@%DpmdhyLtmj~P9@Ba*Xfy|bo1+Y{pxQ!J`K zw00)uJu4YE$bYxVyz4Kczo3hy+lj3syW66$lJB^m-5+$eM9?*3SCI0MN546zeNFK9 z#{ZPG&&W%Ca%@br*}1s2$WlS7Jo{;B@|zQDIwITHcMFd7GOOE9bUk6&;Pv}Ty2Yc7 z%Q|!iv|p#o)C9X3FwvIl=U?fIzZ+-3`k(wE1?d5%9WCyWAC3k0s&VKZROILE*()_N zEg!6>+837H9rn~OrTfo^jfaocNz6;O+|L+FI6%NHLzC(#qlX=cnfl7Kg9E zAfY73ApfhX`MD)0A7$E&87(%mn;4IF6zkkjOlqn+&&GJ#r(z|p{N8zK+l*86;k2jH zQx4Jzir$n8l+Ds*bAD3&GFW*isH}Ux$FbgL_d5pjmHiHEplo$3+MuecR4tg=nlwAk0Sa^!^UTzYgl zzj-;+a5#TKEg~pUa#txgWmD~jz(3*rh8K18q&;<;hI@h%?d2_`3Ro}x^xG`)Og1WU z`&fI4$TrS6R;Er-x{l(vi~9x=tX-bp>I)QZwEpi!nm9*cpVLQPy~4f?M_?KUQSn z;`lKme&NUey}j;l{Mv~6Jk@id+hVx>89(h) zD~;eYc}IinEY+yUlmW{tHlU*Eus4`)X}muW>`!n`f5|5_=uq zJSzMHDJd6gI8enu!}LKMNkm8mgA%MkkeUW!HU95WJi7MAFs6%v5CCPsxG`nTA*KJ% z={L%y7REJ23kRoa6?M$bahyblU~j>Tv2G7p&p_=9=>`%iLKr=cw^n@;a2pyFVxBvv zanYp1WTCwuZ9_yhiEil@BF$5B;>1e0aBtlW%GV81GZpb?rhB&m|+I#H29p+I8P0APJEB6cfH^QBT& zEmGdl=B0XIq&`*rC&2rp{KA*xLgTl?|Ng^-3>NSyRY1L5yE4mvCPIPZUL}tobHc1S zx|dK1$!CybWF*J0_Oe|Nig*r`C18;R2JoMy$3nlZmt0<<<8!?y3;v4ocyYc0A<+WL zOuS~x2PQ6E*C{Jx#zIVSmcnXvtf8E>Y||h++@vr<{s+pm8Ov}eQSTf4hh?)oY_!ligl0&Hic;@0rFA<6#czhD(>c5V3Hw7J zik-CVwr!bsWsi?L066kB8l4jr>Dd1PAtHba*pi-wN>#MmSyURi9CWc}S05G^6SRa7 z5t=Y{o#j%S-`2S9L?UiOaTv_5$4oT<{Sycd=uDt*lr3^v{~4r$ldJ0&sO!h^JBkCY z;njs-yXQiBax!n)?c2bAJytXuMSb_Y>$Yx4ewi0-i(ANt`*)?K&&4 zNaf(*sK;9~1>#d63lQx+y}e+1)B>|4Nx}|D%N{5M0K)ve5}S1kQbXy1LqU>|LNoKp zy}Q#7KzTs-{NcB&#@8>r?MI_78neiH68HaCBD#eygWzDmR-iLO$Y80&>gT}7U({3p zJV*kmPS3z4^ZGBrUcO7RAl$M1F0cE zR&bqSv;G-$Zb(Q7gLE|FK2`>=2Ib{R*jqXvdXcziDA$QyT=KOsF4wDxInwn3-HGIN zq;f|7W@ToE`59_l+N}4k53N5em-lZ_TITOe-_9FT{(8IFxeY`LtST>mD938k@i3OPv<;^O<}{qRJQ`35%K zAWy!|zfN9$=gy^zIIeLWMRu@q|B08xL5v5T+}PyV#wV{e4V&jHxK!xQ*nznup7*Q% zTqav^{y|WZpP!%c?Af!Iwe!}g1qJw7BcHx7j!XMsZEI_b`I;pBuV>q0?Bnw+qaH%) zw$R~YcwP;%wxQ(tvDN}S>8%jm(2h0!U4FVI`+d!tk)6lhP;@-=km%OkT76Ed!}`

{Nc$71@S}PHyqw+UVC!S zGLPLWV0*HCO(>-aq$x=g_*!!fCtuO9ooo`b!*IEV5KL zYHredq0yD0^NX)Q(mXx-()I<_okRJK#Y+#P!_3_uIVo7Y;M{I&lmyVWKPhZnK9Sjm zBJ0|!&?>7xI%&JFu;0;QEOXY4&Hkrsl8|Re%;{%=3cvWA!-jRXE-Q>2?9zByo_X2$ z?V#x!k#(6FW{$K69y}AE)*b##`HyLA{P-nW$JGzv5*5q7Z>ht&uY94Y=depaE?w^E5LV%yl){g_z^u^%kSwMrQKkYIPIcW_8y|?PFR|*GobSg$*TZ=WNp5 zS`TSc=Jj+Ntmt>l@+MHs_MyyXVNGjlgY&Y>cWc2sE+%NIpz}KaS;^_>hua%} zR^zl|WM(E(Jb(x51O$i|M4^dNZ!^O!>Epa+B7LbSBtl3SzoNqi05iP$X7O@a9UPs= zOI^38tm+w3p+pDz{JDl90!4HoT8aU&s$;K7D!<=Q{Rg?iG6^HHY53X~5PQ z8&770SJ2zwt^N@w17UKI;(Z>83%@`qtT#{5_IU9+hi6gwkCY)NSsvm*BKciN^2>1c z#?(2S#Jb@(O5)?E6Ujq#>S!3hW zMV5U?M6sdJ_`(&UHw-HHW}Kb~F{uJJ8Ibx!#Kq-hW{zB<--8VhbU51Wo;bU-HaFMx z`A=(=`2{uPXjUeq*1v?J?UGda$6=5@LA3r|$TNeot__kL*~KkMOsz^uwa!3S#5n_1 zF1#H6N5p@@|8|YT0^#>t6%-PI97c|t)ZIZg{EoIK<>iRvR$Y7whKA zB3>CAnxFVonL=#5ycDr<<(_*nzjfToc6}}a@3-9Wk*E8-NAY%t9!6#n5fRaDSJ9yM z#_N^~F6sf^-b%39oW12Xk{h2q*$w|QZByu zc`(mm1AhdUx%po(Mt?!TqD5*HKLycbBG`neLdp9E7Ys&E2<)`Awb9Iom8%%T1wLFZ z{HR!mJK!Hc9D4ok-2;fk$1SrJgFI}6%%$Ic_o)jz9ldb--TcVN2*NUs96J^cZJnFH zdX^|V15tGmp@pt4Em$1E@gE4h;|_6#N*J^aF-qbaL*{3LG)QVg`k#Hr&k<)Fz4(I_ z|2oM#*q9TQ9}x&bVo`t2{f+(3UAvCp15)ZVNvpwj_3M(Y+ApX?*8lwr324M~O(0IF9qmpkD&NCC94Y=RU`7B*%YYs~-Jyz%h``Jg z1w;g&Ad!t@#j2#Flz>Kwwy#;}&Ots2DRo2Y1VR=OO9_fh*T8_`lkXf0l+K*Bl906` zMwNguI7oF5A5Ki!RhP#=8x*`CA^IRUHx$cHy22TNLLwv~8&QlJS^>9GQ@8nsSl+Rb zkUetb$b0lwVFVtOc4*4g8aWjxQA>OD{)DYrqXoh_m$^F zhk{2CDLqh2=3-|E;7p=HAhvp&0c8TE0I%t*u^yqD6e^xG6V*Y#jw)!%*oA-;9zoCZQAr2{ws@0&#$t)ZYft|NvSaITB3k~1cL6uk>$GRxVUk| zk8oERL-JYw%^QViVD2(7@4%M&CR?8kA45cBB*)|&@ExJ|p@%U+lqfo}MxjdPTW+(K zDw=c+_~dVWPGV2{^S~kZt?KHhzH_v9{XNLDKNZZtt24JwHj_UFOpOrq6MhTtVE?M5 z`9;UhorOmZwqE_o9ZxIp>=L~?cE$(QW@*U)9dPXYsb#y+&JVwRUl+ob87aS6NJz*m z<;^gj5>o99<%ZhQWHzz7^D*CA-Su4~iK<7dzmr*tTC%xIe6iwNRCi2VWYq`5FFy1J zs)d99ly-P!TMbHebYvvD)|geq{YlYoVZUw4INf)hGv-V9sRAyqv-w*6Z>eh)bV^Pz zDHh~SI~^T$-_UgTjMjs9j*Gp)%PB(J%vCQ5w)%WyP1mdpn#s?lF}Yc^G0=Lqp0*op zu9f(q%&eVG>{1sVNvRJvEFyAooeBx3jw|94& z$kGpX(N`)7FJD|$(d-B=%x9EJ%q{$v<0bp9FzadLrttZak-=jAK)uT8yC2b+8PPJO zF4!@6xoN3bfoI3j$Y&cff!Q*RJ}OyRy)oKhT@OOuk#YrU#?PL*7$fcZ)5anRwciv^)UOd_{FMARzK8a;}c z$=UzbD%i6ySUr+8v48UTZYyoK?(@^Xvy~hT)&iep7F|<5A>n>S(n_Eq?_xOf-lp#D zZ_G<5`z`FwP_{do2xp3g2G2Z;q68|RlCXym^TAF^DC-mL#WSp~8+m>2bE{qbH*q3D zZ(=g^cWc(1$@(Gp`RKB9l9_&PlIMB7qcd-h?#ZSVe@M9>SEnc5?))m*UBvzlPnUl% zwMqXAyITxG(}25U!xOKLd$V57cA;IicrLxI=l#FR7sFjq%AG+UGEASo2z9CJPZ@J# zNNoIT>*goTq53**+2c61DSfMiZC66+eM<&fB^{w$y3Ioyp7Yx?+&LQjV9=h?C@UsA zI`qZ1cbW;?^d1l0J+m%)L}T7zRnhSA8lB><*zvGkus*;!t5w&#CD5bx0nTry(FC_B zb!tJ$OUCGa?_x>5Z)JkHnVB4NIY}}FrZd=-tAzCfmai&4jQVhlwc^vYiu;)l9zMj5 zF%70IVwT0x2{{~$ppag5g47Egutwi8#5|+jq-zjHieo1)@IY@OtggNV5mIh@xB*70OEMVJAh5$T7VY{t)yH0=1mE`0;oxf5-~RJI9?uuk^cO6`$oL! zP)EY&`*Y&$YD#N1&@^069F3-atC>R~=^CJVh}egMf>(HV2|6HDBZn`D7vx_ZqNO$Q z6G%#{e-wc?%O0GCK?PJBW`7g4A}FQv18j_pP^F;ap&lC*Dw40WJp_~jo9SAwZ$9TqMwzcwhFM!e8vBs~xi3OBiI?V>h=-`aM? ziO`oL!no?70itwG&#GtrJ2rQ9pJ`*@^omQ`{C;BLh}vF6h%hqQN2L=P7}leQ4jp;{ zMIZW@NEFdEj@~*<%kQ6^&t`Buic5@*C9=m<&X3ZPTBGonS$R${R-IsYSbpF2zxMX> zUCTSv)HYWPN=r3%30=jA!KGbKY6r0*7L8O!&YCHYqeNrpeF6BcxbY1L_9nV3U?fHC zzRA)IB(DMB78ndjo)8`jkV!U#H3;K{F#+w6j?T_xbx5Q)&32_fq54}XYE&W3mH4Tb z%|$2kem-|{___)@)}q$8GUc?5qXgI-xLXWhma#Dte^@}eXTeW@1slX+MXY6r+<16o zNNlc_2%1p(L)NKab%>_|>o2HgY4~q`pEgOPWV*ctRwR*mS=renO`T+#ogUg?b1+3N zkaU-TC~eOCZu|eW0L17A@D_5ADp+BwlF|(*di_x;BfQ^0(Ex!Emm#Dh0UA`q2egL5 zm5`%1zg4u)_F+j+V(ShCIcSW3{kkqp1HoXrdU~E%hZ6vayM5q^6PJ4h)N>-78so}H zl?k$98=928;?=8PPyzs4(9nDk?e&!w68h66i(Jc|?}=sO&nh z&}Lj12()|3|L+;3zQlvf!q`Am9c#(ssYL>I3-|+w!i0yi6E|7tz-1`JfO;g9l+4<9 zM>jPf>^@jBwEfG?Z;Alm#2Iek1v++*BjrgQ{aXyQf#(eq8&^v~XaOHH6x2WrC{K@1 zhS0(gB_gSk7EP`oNy{8MIInPm*Lvk%f!h={p}P9|2)idSkF~X@Xipj$8IfxZX(#EH z#MuG~vc!+eEeRsHvAtkWs()~uK2Sns#@i*_Fa&&UyFkb);U%@&w zXc#SMxFvM}EpfyI^*t=xtF>IN)PJh%Wc%E;2OLjejspY^gH9C&wi*Z-2qE^zFfnri zdFYJ@+-X5zf)2@jm`JgkK#GcptSlQ2Q!KvWs;bCvRF3As_@D)$1+ZL}l^+V}2~fpp z^rLRgpAhc`1`*vSaM`jT6(#;soMdL^xj5c04POq&a9)Zf3A7K_x+cW-?iIuW19Jo! z0cj?%7ZWTmPvex{MXXJuuTqU4mQxitO0j9#wK zb)PA4H9p;V5+(lgr$?#v(!s0|Xf$!=wBOUP?(TUGItEBN72aFcE^dH%Tm%(7;fL(j z$JDm9w{y3=z!C0=8fEY(p}ZZTE2owz)t3wLz45?dZDEIvfiD*6gbwgfX{>$StBvX8 zD*!k2;?Tc0t~i~mD#8w3Eb%om^fs${M~sh;uWip7uY0!|s?A)_7|Waz64E4QdtF>d zvT{1C66^CJG)-%dsFPYHY0a6I~T?~s$#k9~WKY)vWl#j}qkALqa6;}-nm zu3?oOCv76(oW*rDhi<=!YgSTABx}M8>%!6Zyw^N5>5kWg4&2s0|9e-^(;^e~6V{6* zdao@$6p1a{+ZRr2vpTL08@&^dHv2ckGhg6aN;f0IW0?Fi-R!}s?^zcux33=l`X@4# zCQ(M@>KXfrm|V-9-JEXguD+MDb=v8_mHq>3aFyvd>1FY;y(|ikO;q*S?2}3+-DYNA zx2)}-AG@{G%4Rz@t!vE@Bz{WN<7l>jq`7{U7zM1sXXxe~N^5!Pz(ADZAm#)z0o)pwxG}zZT$$K5bcG1R=O1u7jS5(bQp61^f&+~xSwe)C~gi5J|0quU}8hvW^H_4?Bx+AN_W4Z%( zdRx9KQ`8*$q-JojWHHyvt;pz(kZu1zn;AFp9sOaO&tw>Dd%6vjJ<@wq)IkxDJgZc* zm8Tn_b^{T&1*38e^$Ko!*jq$aO3~^cW3D>Z@UQGshQD*bZowSq$jKG&r?f%*6R#{ZaT&1KCM&z{o|;e zJ9E>IgzE(q?JI*3sq8`GR@#38t6Spjn@2B<3Atn$_lB_Fd4DvAiScSg=px;SMUptf zwPMw4eOdQ3dIy!qi-pA21RQ?Z@7lEs<^fo59k4zjna=S4LF;z^i7>h#1_p@B8iF>+ z4bx#k1}6wqtYdoPK=4Tm*VL2Z4MPe|b$O&!JNYRS5|VCE-~GqO$hmT2B!x09*}{ zqcN{b%_3#`zEfJph&SCTd2M>>xi4u69UqxZd5G_!D0=1!A)&eEMFK+BNJD zqd#}dgvMQcehSG;`@spKIhCw=5!{0hg5XA=ocoZR!HUW;pj5n<#DR6)?I2-$0RP~u zXk%(m=qD1cJ079gm`!f`XD{Ym>!QdHGI* zmMcKRqKj960?vOvh1dYRxw0W*gZ%<#9SGM)vZ7@gX8w3VT88#=&q|U{SW~mvYL4{gV&Wo{nYKLuVp6l2rq++W?>6>JW- z&DcC-5@g#Rcb_4G-I+i13kb9peU^&-Fy4#0Hf5muG>})Y(jV}T59$Fn0EW!+a$W!i z+z0qG;^6t=DJQYCQ5P!_m2>%f=__K6z>Sps^r`Y>n!{l)IA($hoNxt2Kfg;gObmH7 zmWc%3$w^0$`Y*w0?K^(Ew*LuW|;vm;$sj2ZKCN$WMk`4BsJ$rTzrS8`g5`?My#>-O2|12}Iu}zRW@8CUk zEF($iBAtk2yY4ZJBPdBF)rDkNis>t2b_v`!36`Y-=Lifmb!`VA@oC^M(o_!sL7Yuk z8f*`8v040%#`sk1YeD-y1BxJq7&Hg+tvX^24rY)hOao98zKl5B1=N95a3zc&>5pR$ zs$KFk^kBx|-bbn|MOB>EpL_e}4Y_8R*3f6P8Dcr$Xc^r?lQOG9q;?;%wmygvocPRv z_{cMH-bb;=HZ8InytGy#1=;H_5FFO%vpD!^RS(#zDMfzpqpsK8-CpZ0lSzdm2{T4s z28Nhlq+J7868HrnC;JXv>)MbXyJhbt%Ve+#a>dVoThY(c^Ka4t!3$k3C+;)o6mXj; zHf|O|hfHzJFD7IKA{UcsFChE-TE9JjW-tbk#>VUdDe~N1cxO(kd_uy*y$~Y_?dc~i zFGF*4Ok)3wbHFT-mNyGZ6sEH`*1D57larYOGNx3)TZ|%Ze0u)9e>B3blT@YtZ(s@Y z`56QjhQR($`}|d}rwX!3s7SItNzv9l)Ztnu;QlNSw<-WJnT^37*jb-mI;Rq@qR>+Y z9+(_WFHS%T1pYY(J1(3aM2`Uv5eYd1V=SSeA%rp^khX-`Mu7^(3f;v)qU$8Se_Uj= zWs)qzqUs-ZI6&>m3Jyt( zQ1G_1pkT|+?;y-__~mnOKcdj(keL}}_;lwh`Yk@!;Ho6A3~A|HD7W|gc&5XB1iUL7 z2Kw%Dhrw-8qSRyKcd6Lxb}<{^z9gAq&v3jz(g30DFUXIn_1HhHEMv$;YK40RX%Wp7 zQTF((Nt~V#MF6l9X%3|Jn++2!o?p4Zo<*|Oa2z9`@G?LwoGtLu(iCoiV@xO;u#6%Y{k z65oCnvQnY{>k%Lkl;kIQ@W2*98CgD`k<-fUMYBOjrcD8Jp+j%JX+!vPL5X7l4Y$siJy#Uu(hZJGe`QA3|*V`vdx_5FLbR3sSV^%ZB zpvyE$T z_j?$$taMM@}bv6c=}m2OnX@GJl~O19Hm{@S4uIo;G*L{hw+2h^tQ)Ls%qgA)0Sc=ihWA1iEY)xsx*Hw0HICDkPIXg^TuOZE@JAT{E?!{Po zQSStCCClvD?QA{A_MHD+p76_`4?A`DAp`!1_SMHnFyrKz-EfIU~?fXhejh<%mu-t_dZqboF z`Q48@^cTf1YU_V!ndAsP{P1*Sz2-54&)dy@3YrV|rtInD7C1lmGT|k~b(?~!YEP1O z1pj}Ui7d2-nzFFG(Aw=dBQy5hR)6l>;=%lU$Cl}ll1$FKYm6DjrlsP89^?Fu*MBph z<8^nB)E;D_y zS|k8!{*wEj|8q^Zr|-FigUA!ridqZvMq0&CW)=Z&kj5R%O^Bj`H=p^iYbsA9 zvBrXa6P0v6T6F-zIT4+(h%ay81wp_j9Rh%PwjcE%jxB)AE&hwPgk=N=n;!qG7Lyagn??xnoRlGCwWd zXum5!NKg>kP2c|^>wVs!b+MTjhf@T|1E(xR4zPZUMo2iI9uBJ6?z(^guzIpm_lZgl zv>e8=GvGM2lx%Ixz88Yuvn+EUf>CU&@q_2#C8a6+f<$)$^Q*cJcTKR=9%3Ir78DFr z&_;ts-GfbpC=!0|T_LL=TAqjXAx|BThL|fd~mfNS+gko8A1Q0=O8Yw{GYdI^}88 znfSfY_G1eIwhXYJxP7JcpcH*gmh9ls$bdrdC@4n|$!pHEKuVhl_{+QPZU;X0qyO@> zajHTn)ROqLG**7n5E~ytKAVcL& zMZ_xuXc`zAa+sHq-4=?I3=@7{SuU2?djw`>%*fEgyhRbvDju2WN?tIYa4G3N{SWC~ zO{gW&b#d7a>RET+KB99#K2pl7X^l1Qsn>ta7Xc^hq-^_!+o=K{jIl$Dso16V64)Tf z(GKXIAiB_&InegH_(m8;NU}NLMG8JwfF6jAly>d9^mA(Z9_A@L;7N#czy*m0uMMuNsKEbV6RYy6AEDgoO*w4v9d zbOjBa$ia~QF+xBT0)P?L?nTfQ4D_WEw+)lwPM5#*<3!*ch+c*Gk%d59ZP-AL0OD}B z#)1>a2Ol(f6dVgt4`FJB`dJ<&T6%Gt4Ut>|oKHP%?>QJ;dhkI}kpmb2jBSJfM|qMs zyN$&2lNJ^h2?6j768@Vw6Y%)rtZE>k0BDH44iTS0a=br- zkyOiKi_lDFubU_!!u^JVo*bMEjg5^YDh|6G6+nA>Z^FnR`kBA~r+_l=xS_)NB3{=$LSW?lSyJlPxgkKa0#bd-DIj13 z$2Ej6}d9?#uAc0WgCig5op@QNT-&yy0>3 zqRJH-&U#Q*W@e4p%|^z?@_0F5Sxg=ftQT=P$O8n13#Y1Sqqn;BD7Jfhgyl}k@a1(T zQw+7^n{%d>nA7qO-aCK~78C6mkcOj`s)}=Sq>USDziY3?pxt62Z2}p&UH4|d^^xC} zG=4yFMnbs>ph!zggBT$?LwAqIXDV`SacW_l-s1Cl9qF8VkuhEki3DaLhx=oW5adC* zkNa>Bc}kH1+kk8*l0FX3?OGsp`FmF-W#w5cuBj)YJh0y-DSRgp-G44%S%8dPmpj4FZu`G+kp@YN3MG{m(n3p`6b(tJq*AF=(v+44O>Ig$ zO(NP$L!zxAN=uQnh0-APf1TIwd7k5Wj{kAo_kG=0bBR0s7o+EOjkOt~5iPH&scH5b56oWI8&1qEvWA)dSdv^vum)OYzS97Ez{|TZw0YAeLMcU} zz=;@h?ZywEFyn*I2g@q?r}%r`^`Af7LSr(UdMujv z{q`-zmz12@|FE*AK3#L}_K)xA>*HAW{2JwO>tT_N^`nkMrWA(Vj4BJk^r17KOcx#x zEXYmVScu_pQaSPXe7m-qvny4HN9tqSp6FDoT-L$Oks}-*&zRc`nlYY!Kd{4RWXIc$ zoci5=w`8=t{iR_S-K3T;n5$2nG?r_xJZ(H|nrj*^}JX_?u-9J5hn}IY})3_JM zkbP+V^Qs*iUH(r-Kw0(Np1sK_ht>&RN;6NOHgE`3xPLagtKviWcN%)SowQpy45_M; zB8s~^=!;F#T+GzQJp^)2zMac`#7V<(rH*N@$YP=1uaO5-p-QJnm5FFD}v6mDw-)tiz}!arf=5e%+y7^4xAw%qL3T z&^&5r9uK}oxw5!V!fQey+K+i#1^vQ>&5W@!r+1w7XgC{R{X1>rxNL8hAIpY^gSG<2 zLdhG1a%DV|_-a-;qPbJ;Ls<<20wd|43W~)@=JD#wf6z|AhY>n(Z0Y*$d8LG`h(w{z zyuK~*njcTgbn~(s#6-^P7&v=SZ;QY6w2zXaL*(jx+<{Y~_^Bl7@+bgrv13R{)RCMY)EB@W-)!Y5H z>{#ZxI(OxVR({k&Y9AOR&)>D~FW}8n0|`HOdxomtC5uBx;CSE5Ep3^xaQW>Xf zvWooPN+~3f_;3FOrcfcb%u5aDGdjvX`ORkdvV6KWEjT=VqD5u>6G%Ar43jRC5p%9^u;HZA~4Wbh{>;M}BLcqsqP5hNn>FlD6 zlpAU_Q$UT5Htx$0%X5`5dcd{=0TC(?g(D_Zw}DjOyjJti1r!Y}gqKJz;W#V+HHMsF z5U}|vza>8*phI^p0LhRjXxGIqz z2YW2+zc6XgJqziA4}H7YIit9Wy-~cKp!}e)_mN)*g`O@lIZPSMLwMDI!@&hkP%;VT zcR2OXME$RvnKx)}a5OzcuLQ4$AZC7C2pQ6$puL9PKRoo2NIpykJzg;b@ixLqr71tDa9Y3X=!+9K>^ z+6xWc{^$aDscIuv=N${`-pG-mgjj^zq$0w}I%HjguqagifmuU1 zY{iAD+!v;Uc-hlFky+lw)hlZt^nm&ykNF$28n|K2eO56xCWaUiL(BoarHnFz8~|%v z%MmB(a?sGm`-=BLlCkZqIbMns+n+Jwn0k5B79k)Y4VINXKs|!q07*AX2d~n_)!OM8 zyAj$9Xf0CLIJq^7l)&nN71uC<$6U_W8PK?A9+dmqz3LR>;ShFd|Orb7ou z26&01lEo^-D<}>xcmy-TLk^V}Nzu4D{xE=;jDbv*hNck1QgGNYMj_-mh;}6fKH8r7 zM+RMZ#;$|dNC52kL)N|>yT6Gy{TEpViJT-+(10SZ*m_Qz1rOdw*j?f2ltNh?Lqm71 zx8M%Jno-h=0MP*c4AvbakL2CMkr_Kp0nHC@ZNVzbPHy}%vF^rmfT_&_ofT#rCy-Rpt%F4PlT_cU)On1g;uj$|2T_7?a>*UOGKie{l@g>KsCd0RRvhDO8YuSQ&Y~m5HiM zzV(`D#g4ZOYy-9OSJ(<}`|<)|n9z6zVMic=C;6o!|3)yGgb}bA9zC==-bj`r-GFZa z^;S+Z6s<%$g31>&m^bm=t2dD&{dz3k6f{^L8@SLziG})~e|5Fh^$ z#-}LPN66G<+X4zsGJ+-O6V@aSqZX`x%HsD}jwE{kqr4p!ZHaq$d2a(QLj8{s1`7D- z$Vf(sp(i;xx9xgCH|sDE)lU7E{Bm|N#MtJq%b_twHp^^0PJ1p7>KS*bUrgA}#6)O% z<*<4rcoS|B*`e1jmzy@wlhF&B?AO(w&@1TaH6cUu6!ro*jxf*)=<0%@nT8x=n82Tx zrDk)_R|(mUygM=->bNN0xsUL`00NL<+w@~d-duqTxq_ZadXLZ788C*zpqv~6gd-;- z1M{f^-EIMOwZO1ifpQ?;fRzm*D(JeX?|NUy@kCn3wYg(!Jl%HzW049UR91iv#205N z4@XHd1~HgI!{gtN#wEelzB=yTIaJ$tOc7y0{B>y?|e4_p#n_X5sb*U#XIxeNzzOF@w`6r!_{q1 z*OO5tf`EywBOZR5-9(iQs@c+XVfzK-13Y+v=^}}fKwn0>Y*7EuWuy8aE;o2a$(32> zyRiS;oR7ksCHKFB8A+CZVzHt6oclfgQ^5EiQhx5bp}wIm`MIG9HGo+#k%q4CCClE$`t zE_yePHys@%L?t%)abGPa8kK?oZkhQ+3?>sLQNnuh79RKU=wCgVOkQ{npwW z1kb+~zxYaA{`il;@dH%N_Mxwj7e_O;=n8!nVm`2V`uq^n(?Cb_@GHONCcaExyE@Fi zpJyvm$5oAce)>a6yIlUA-Y*dS(5-tCgdsQexY#-&b|@KUH?hoKic_@=3SajP`*~S?RgIWj-+q zXRN2l?c^;!Roty)F!Gq{yTd;BhYg~^sVO4b^pg)NtbX%EZuA&28!#{UaQ2gkwi4CV zTup~U|IdfJVhzQg-kiHz)V%TWiKmzN49ibwaPTU<6xXZoGCDpVa&9R6>izeZ-cEL2 zp1R8M)r;wSQQU|O#Tl%-O@6;$ERE@-wOSp zi?!bpb$#??pce&?RodZ60cweRSF4oL?e~vGa}9F}edgT1$Kn)QPEdDR`gFb<<7X?+ zdP)VRA0;telSPLq<^8@2Z{fL`buiiT0{hixd!j~F`=1zi>Q}_OP=V`3JO7S}(;lUBu2wzq8=mQ~tJDvNms*(WgS*yOJ>%vmW-pe`zT6 zc?X{&(~}{Z_2xsgUv>yFU5ZY->AY)JnO(~0)!8lYr8k&A58Ey-*Y+qx>1VUe`q;9d zs?2xus=KVby`QKnitHbY_*EU|kfzkm@Y3dRlff7MwM{jLE-_Ocq`Bnt^;~HB&}M_) z%bwHjW<5Vsq@9IOM!*Ov0CLt6>ny~}JQ(a7_=~jhVGQPz=r1mOHUsF2VR(lqSHopd zQjVZJmErJ3jw5HdIDA60^$P#ZBL9!zo4nC#WL(?swIB|}e;Ea6D+*(vBYext%dp)e zsz2b}NtY-%OnN?+o;nMjYyj|omtfW4h9I$Wq7Up!nXg9R_#_{j8Tdx%tJGVtbF;|Lh; z=@=S@?r#A_9_|2%@fC)lg7;76TeNd#n@oczq5%ks1Rnyx472GikOCGA78XbmoB;7$ z>aWC_5Z6&`c`rvDBGvf_#*4EP5Kyt>qjj3JPe55BrYhOVxK{t^SAi1_?(nu5)4;$6 z9T~{?52vIthap-FO^Ku}hCTM>7sK}a!5`H0s|ovM(mn!?1W{pTWPpV}!?ulBd;kna zTG7n69uc~$h$uaIBoWRf z3IwYsNq-0mn-PjuVW4y}v_NG*mI$g*@-sk4YvA7s;MBbIt)R@S@1Dks)WFysg}q#6 zOl4ZKp8PiJa)LsYpQ+@)na3>tCM4*F(f<>a5>UEEC0|B0WsJyfqVfZoon$tk@cCF` zn~)w3vthIX4(;81y;glq?PcWt=6R)IY5jsC)2Vp@#IZB)z zlXm^8x?{qIZdC+%n3wa}V32MG`(w}{3ZI()bQ z=_iy5#%Ovqt)Af`n6Od7>oNnpHwdp!BbNj$?j*WO9J3Zjoh3Ng{Xa_%;-tTc=K&cs zMy2$kD%v7+b)D|Ly}cP{1kv1lgi#7`{cG7&#U;#rhv@Coj5L8#2iEkMmb@ofveKcTZB7Arkjtn(}6&)!;^8nJI1i@I!q|Z{vkM%j{_fz z4v!&*xMbuarOA-@6yge0S@PRM>|wi2CKV|8`hPQv)A^x(=P$gDW#fT4OVY`x z%#rdeD>9zLJVvH?8qyaMO^PW~QP=!`D-IGCqOgD?C>X8q5OL~G&9X{rqjX6LL}e=T zubMy#1J`bp*6tLJZRuUZmJx0EF=@3R^yWezJZcv4Q(Cc^h_2ZisruiP7j}Tk+N2f+ zWrrd3IE1kM)!^VP?4Y1HAAyfL0y-z&VV(AirnzkF7b!Py z*`7_UAhE+7E(mD*1|#?O)0Sh_rztC{5shn3EJd-Z1cZcOJH;HNEa~u-1~-Y29h3f; z?q#C2NXAf61#B<8DX`i|=?hW(ps)<%)2dfvbN0ly87)}@1!w4S!lXklCe~zDV3-=f zAz@F9pIxRe_|~S1x8aYGrWkX`v3V*4&102i_lrN>q89ZLE`Q*tk^D&zaI2F>csXj2 zY(qTf3%B}YlJ0$KUM2tLACQ&6$fW0gQFs7*H2N1Q$8Tf|fcZO&?ZM;?gp9j*Go|Ie zTUl0O7$k`yWN9v?exDewoU8m z>imA>Kl1WXAss8wIfm4QGQl8Sq|0#t)FL zp=(Eb#|rVx$tWBGf&@qFSWu&pjT<5mOePJQ+x$voVU0mHi#BPkQe=>+K^TqRWxQl4 zuIexv+>Zy0W`4B;Un+6rFQ|b-$gN^N=-4GojJhHUWyWXo<+ENwBb z0#D+7PtT%|DSUe{N+1mi{F`trAGKN(P>TY|v&50QXf`3`aG#94`jd9~g^ibFl`_M( z2p-)k*A6qJ(K|4= zn6xD^Ue_(_)Q0LUrQ05Uiq7wQpC=|EJD;2H`XaXYm%+pvyEol-%`tTp{c#^SD4*Ip zi?pRU{CR!uQ|SlBES2n3gD}pQLado_4}|}8i&XR6JtsB!SU-4R$8t)&@yTX=ewXbL zf(wfDHr1B7v){igW>N7@p!LnJlVy(?*rd2`O?_YV8y16xcO0I__jyY1ZohRyCv#qFMel~K2y`L>8$hYJ*#k9A#Y4AzaWM3#*RPIcWF0R49-ht zBk^F2rOD#P&>4!2%B)(v?%&RH@F0%x_hG_^==?xj7v93>a>Z}(G~-j6e4L!Z zTlbgzth@HqGCePk0opW7Ic2CDwY+{Y6ZIB<-+}ulKhG^!9MbPM&8xxaH~KJ#L60*W zI%EmoU-2OWiPgfudVC=-Ge?NfhN0ohKpQ=@^6d!e5v_)X7+qdF-@q0(lU88!F_ibS zY5LW3=gyUmR@f+&b7bCA=VMff92giV@@Idc1#ot=FPCJfSz$pz;G4@~otX!2;h4m6 zpY-X&cI+LMm6a}53odU*!wM{RzMT2f%rPNB=XY8F_RgB^)Gz3i&d*(ufh(K-54S4K zx*B}MYoW&u+bH?Yc;}~otZ=y3>rtBUFzZD|DTkfVeQuUbj5L%HyiHbaoHUfuhW`C> ztRZg#mATvQ2*;^2(NT3tZ?vtL&T4l#>eSE9Ku6X8v(Kf?Y5GVceZhVi#t^s#ejZx; zHx$$`LJXgYJ6KZUnJS;092BtGD?uEhgE6s#x%_HTP#%xRup)3pnHw%zE?*z>?CyIN zV`va2p`p~{956Qc)?O@*mnmp?ZaJ|)rpjkV0wH{wb-nUj5haf+v#ikM78&duoH!}- zu17extMjwziO=9uy1cV{d^T8>_lxotzGyX^_g&rpN=7BfJ@@$7)~L+=SyJcErv8r4 zE-zq!k8?kJI>3YtV{)hRWGBNLPMNmv^SAio@{U)05{_=_=67MIO_+7bFgz!pTcAoj2PTD&SmSV=92m|psqHISp6>EqM}W6 zN3pSB+HB#Y`QZwEt&+;Vsm4(UsjyzuCMVkZU(r|QC7r`zD!Op^(4olK*q$~~lW@3s zyb*mGoHqV62$W!`65d_8-uLI-VXn+Gf&pU_^=4~wJ3CN1KF1OcLzc2oM zF*r2ibwW6HGZydu?MgA`*$-W%A7oA3<8EqcnL7M2eMTG?{kr7co+*u~@tFy!{-q)2 z-P1P-Or-(PS2E6cv=4n&w?>2=?XPI#KNsP)_UJ&Q<68dSR>6#Ru!AJNwm zJDl582RFXGU1u_OH%gWFG|L;Q0Ow3n52gu77g)%YzlhrBe$?l_g2~P}Tf^@AjNP9s zW<+tx(f#oAXES?NN!9h`^>}l97}C+}GZuMp?o)IV@8|0urA{Drc}#t1Te~yvNC<}j z_U=USbJ8ED+OEyp3~5!oXh*N4_%K501ZM95b#Pb5Tb+gq+dWD=j&n1@R_Gm7^T(0= zv8J$hrJ-#&KOpV>?RKF5uYJW!!G0e*i-Co29s@2($gX!TZ^&{Nj#QP*q}i*+%^407 zJi59h%fi2BoorSI87ICr%FSBE`wz|D=tflJwOHBG4Y#KqYYeMfW?W?Sv?j!k^uOCd zNA)6CAXOkZn&REI>A%OdGUH|zk$iF^KA^;1M+?U*oQ;ci#-<72F~G!W-6ZIOu7_FStzGLTehF!*PZs{gDC`PUV0Z>{d&Vz*cKh`jD+{o;UdJK=G=hQ#-u{zc*o*x)fzRK{=<#wqIO}w7i z)y3bb!vA-GbxTUX&B>fr4Xq)mIJ@8GGhXo)zB+fh%C%CSc^)XdrN6X)p?Gadc6Rfh zv^u$i>T=?cF+C2XJ@vRPs-AVI%wn^exO&K=HAeEz82J@Lc$v=^Vq@H{%kNL5}*6n|JSw-q1C z_1u}=g4|J^%ju=qt&`$Il7;8*QyUf43C-K8TZrw}>>2Qimjy1H(%Uz_uiDaqGCRCl zJZYEUAC-FnT~b*G(}blBmt=9TK0c{=B<3XD;G8wpdvJ%?%Bsh3$oNuTr_=^s zKRU5SEl0LRIIhm&-c;3^d6`Qy6FV_Wgc8 zcfFer)u&O7iKva+imxhjdr)qzWg&zFE zY0OqHxI5hBLO@8#zvAk^N*SGq&oivd6T61umfzyC%Gs&j;2Ym$q+Hdr%4g%twRwrZ z{Q8!N1K9yvD46U7GjfXCiaRx=cV6pL9DXmYsV?&S+%I*JjSOx%HP^Z#Ag@gvfWk&JsVcL`YloyM|6eRcFsPgQ{;z27heg#)HPPz1v)V^=#9>X&av_)_Lx`^t+zu zwnBR`>jM4B9LT`+-n`QhrHqRN0iV>xLi@xSdc1sp) z{ygXTTkN_dMCtsX@`yoNQ$EQ}hFR%F;q%Xje%`!uw8r+}+mx)zWIDr0RZ3(qAP%f` zVmSjn45|(;8W5Iserf9JVhIpKiDgfZ zy>RQkG~3}`_0Huvg>_YS*MWtK@7HQ3!m7ja7O`>o4$U@<{Pf6ju$rsoOxfRJ)o{7p zN~_nF<7Q{V7_Zt%86(@#Mv=QNj{V!(HP0=)`&s+Rx*~1gU}11tQM8~o3rx-Ut6bc6 z=rmO2@&B<|E?#`~X{CB(wYTerw47ss#ddd(j=Ado1slqk%HPNN&g$Bh`oD;p3zO^> zS2*$K&#M}(Ih&gm*QMKA4p0B?jFyvP=jiUY;Hu)3iF4ODA>Hp)`sCl^7h@V-)5Y{t zs*aMoV3oY$bu0G9`Z|%*ESZC=n}>h90U}GXEFOJ$(DdpHdiN(f{oi-;J3besOE27@ z?4-2dv~|_~T1%=G&4XO0BT3$0DYn%1tg56LC>Q0XS{ZLj&6d_T|GSR4L!M=0Mrz<$ zF7Ck8J9P6`HW-RB3|Ps%yLviSwp-jG%-ikNOH0wtC@Mv118<5)T;@Tl3XD=l%HbD= zI=THvr6+422%TJ?c48fE(UKtBt$4ds-*F~}mG!=~6<3W*oH;Yx)i-K0FnF%_rKz}b z`RX^p%=WViaO3v&z9_J#Ou@Ry^%=j@_037u7FW4;tUpr{c#ofJEW}cS%B@j8ob3TO z|Ih=;1NIj@976&rD9#7b_^YIRx@`O?%RF$JW&x4u3dI^XWW zBz9=Z-JyhLJzIrUXUw%^brBt|W#tSTigMN067qKMI}KW&HN5A4<^8_FTB(`2qmIwb z&*vHY1cxwHs`pn`#=?z2g=nYkGa^^J3K3vQ?Bv;bU8IZ!&*J< z=GE?c)?HzmR1%p-IG=WTKM1nB9kLSgZ#ex!=QVEC2#@nTfBe@wWDT4&VDl$-wNH%x z2@$-Gguohn{#+UfpF@B_b@T@19%p`~*K#1>_3b?dSpm+&{$Vd?I10{+ zO16IgQ5Y^MYdZ;BSKUu=l|}MSVjXdH#*q@F>|HM31 zP7YJZ8yYIVFp8s`;gE(x9#J==nUaMJy(Y7FdhOYLKeMcMKzB*z+yDM95~iVMqU6uV zl1rM@;NuQBwvQsl5ehY8C=D+7u}GgoOhLpW>RH94(3hW*pd=i`>4p&_4B?vL%}M6f zpvCTw&-#7x9HyY0;mkCYgaZ%qB(6i(`}fL#$$@mvbU(9-54-C$YKA7y-l^!&Q+_7e zbqIPOKJOA5ghY%>ud55Qr0n0xCo^H=2xx}r8gZ2X{bQ~QsYbt)hmC1Q%cw)*8i&cw zBAX4?TCy-kCN7TPHv9k`*a1l3Vw37`HIsP0pRSO^szAX>^s!JeV6?y--hj-$nI*0k zv4Z`JWn+CKrHhPeQ5B;pp2AaQvuysKr!^pW#Z4#35|FObEE~fcx`Hf9Oa=M*8zQR! zG-Gs)#?)0dS#wOxz1{i6Ai6jafQ`6<(B+~%axK-K$*QVL6~>0qof(yEj^mh~F^@7LI+=QkeA?{oF;1Q0LN`sR~VACi{g+{5$Y% zc`!ib^(_q1FXIM-Qi~}pKz!uz*#A%Z9=KeClNbBtr+>O8Iwn=ImfksvwlC)XeS~IX zl0h!yS)uzb$-v|B6BHY$h=>F?0l)#t1VQ@)ey{!?^)9D2F3=EYd^^;rGt=9reoN1z z{Ua$mXg6W=IurM%I~H|0AYuk6fB_uCbEdIiV085+pHi0ufC5Z=$T!uZVB!U%RSLbB zWA|Gx81_W+LVGbgFx2?ln-8=ZmcF2UaJvZl%GN z60(XyB%P0W86&{}n*)GGGYkv>9HI|Ls4Ad6!WqWO!tkMUU&$r-Ans)f0*bF8z~bX%wzLDlqmqUky()0)3&xh2CGCFm&tE&|@o{vXdeW08>QM$b*pbZlRUa#EIO6*aG$k4Uqdx^-Xp(j$ zsFOufS^>$AUNYjH_WOmt48YbfSd|1ucJ)51dmp3~oJUqL&`1xLb4dH#CV_)ECkGLz z3HTPNox#zkrScn;BPQ)cdx2+(xDX*_!|n##5jsC8;vnsVfaF%tL%TGt#33KG=jQV&~!?O^Z1n{_YLO{q_x(&_4A>-?4(d+7^A0@fCi z8<8I*iJjythGFKgZ3ye+#()`v6)1q-fIkSXVw--+!Djd(Vn`DJ8FVtfOTwGre$iuH zj>!@J21sH`a9so&s}l3wB)a4NCd9`N26(4g?jX`)m|12_?i3O+|3~9Uth{jBa6d4n z{#M*`$rRakuxfJ=infU7!?J+S0S+~MlaFOnl9=4bIcYfek?<7|nr=dLiOq6(Kw2I* z2tt>OPMs;5HnEiw$K(Vq@QT zkoDrT>=91o*lQjuia&x}42+L3oFU0>zfhGWTS^4Ck2xk9k*FA4#Bs45Fz8raytNif z92v;m(e)U{OELTPUf~K6J(#}?5e zMUbWjy;~K=@q8#{$m7s|?L{aQg0aRTfjK#3noaUnU<-I_-BZhc@cSIbl5RMF#OAWI zm~MMGWUX(Xd;Q+WtM&yph9oJvc6RZGmE}0oKjD=#1MLxED=B8OsBs5^M;l0trER#FZ7>1pwjl(RGZ~DXOP%6x#O9J#+Nst}Q zK+PjFTQYPP*hG|`7_G`e#6LYyE`H!2ie6YL&aJHut(ig~6jsejNukko;f2cq2qYD? z@5xPY$P|)92Eswa@P2c6gP&DA5LDvvK=h0-2d zMxd1lsQyK?ek9fooQ8mEq(Os8s#obM3b7UH?{(TNqI7jlmxvEfd{%L~mNpA&+gegO zpRf6&OIWC2sS?Q>aU?|9l~ivJABSsYpQ%VKi=P8w?fO@1^d~FubmY5#vRFpzqn^^* z&)sXa2ponk*>Rzr=v~*AQ`SCXPn9Ri!n)Q4>7=IMjBe37&#|54Eb#!4+L02~!T2D@o z9}2`|2T#MN+)stGH$kr*3zPYt3QNNa`|F`5hE1jv>&5F7Xi*vq?&jnWvPT1_`(2W} zi^5K_!l+9EWE-p-QtiY16-;Iz7N~euf5q1PFk1U<1b8L$Xi7qs8VW4(sFIx?r8=<& zBqrIvetlw(!#Ew3aqIx^z#u?_M0SC`qtC5&fcgjpR6pS+w&@oo-6otGd|>dhsSx)m z5PT756Opw+chKCai)~K!-+UKQiY0X)?-xCE>AGYAI6Y;quU;si@zWmd<0{{vHZ#*>yQCY_V*5u5jZ z6{FCh?mWO%mOJjuWA*Rl{AL=r_S$!YwP(4keG`12thgR;W>aH6YG$shsz7sbxb~28 z*3X}8Ki>xH$_!Gn_zdiobEuhC*|f=$k-4f@zLoK(mh)$!x8GiyYG2;Ta(n+2dnZp+ zX~&80+y|HBx5kb0tD_lTWbpCz(Q8y;+ENo{@<&F-xs}&7`vpzW4y=PSY!3pnU-^EB zuNoGPX_Oi|Wn4ksvyM@jxr!n~mXW2aInQ$QmJA!lHT5IqtnNSQVsjr|<`|e&eZcti z=dClA!-9_5uGWW>wV94Le<`){PUOEJd$2jnMECB;{$O(tiI;`~+Ig*7#xc}gBAp59 z^MQx7o;gGgWXff%kE}md-=Uk{zBj4$f!eUg!GF)%6or{ZFH`CJ1r0MrIBYC9=)&`x z)yhps^VfzC54DteXQg+S&v|mijWaSIwc#EOq5S3=amrkCMeE*sir>*6rq3(T1)9Be ziWIL4xfQjOd4mo`j~uNyrJ`nl=$kzj&jMa0tPfOZi8L(=Q)}JUl)P`xwh#$9PX6-q z6Yg2^)>P-()*oCxwbFQp+wCK+9nB?*TlY8jaM6ky@m3!Ge6K`or~0oL>PCZUmDcSY zb=%ujo_jyJz;z>kYk7@wm(j?hhTl5TO{oSRt$%}4p0W$yeGpWjU1vGXB|{a+>64bQ z-YjYC_;rJ+!MkxZoD}VnwC$G1k{QJ`ZiS`4FV6b0J*2Gpwy<;jFcU+Sv9kOpK3~O@ zudmL&6hE70sdD&Je!=}i%!fo-m^*ofr3_vjeHCEwK)G(WX-xFDxtoh?P7nPq3VcZ! ze8eL$9 zd`TEu;e#{EH`bb^jieR=xomu)9EDExzZttin_e-ZM}azILfP$B{CG(8-W{Pt*MJNI z-z+P_RzhC})oR5*`j@y{<89i{`{sq(%H z+l?~;9G6uPo4ce2;OY{6G81p+GPuD=Vq${+7&MoV{Vf0>30Q)DEmm94<*bpWCO!6m zM>&1`)Pc{Y(6V~{?H?+DFreZ3?(hcU{c>dB`UdpK*vgI>!!*j%rQjn~6@hHXiAz8l zY|YT&pm-of3%qPkL-}s!HMf&L_r<}y^f`R4SCIWbIAWI^U27V9x68TeJblD>vvv_W z_#!VuZb}3O5$18IvA`qtL(DA%0Wdc5!!Gu-D`oW_2tC8-UsZF-W`|lBB?BVPhqo^Z^A+PRiPHKU%9x5T?^wV~9@gi}SfWG*JXim;#s|nyL z7(x9n9ddYPUKGFH}C%_KCw31xUSLWj>ni)nsvYG%t_=|GzWeVd=}W$sgucn{611?mWQZ#_aC0n) z$kNbAfTxDz5yX8lP~FhBETe&+R36&CZ5u??otgV_#K8nFPQ3u#JpUYCFABF^vVO88 zujeZv_)Rm(_%5Te`$nEMyfa)BB%f$ry2Ob@k!<`pOtQRwh8V%V7x?v^v(Vb?c8#iA^tR?h&j45BNun$qX&iMoxoHG*c2Te2PmJs#%axv*K$8Zr5zR^Y zL{K&h$o+jl-rK2dEj%3Xmm+wVR}>=E#>~q3muR{zila2 zHP+W{#XWFug3L@1?4x1yCL<*74XzKsG!Y1b9)gP$2%y93nV1P3B%zo;T**jKQ>xbX& zJTxmhUq0q1`JkRViYj&eFiiILM5-3K*LZH9p8RASpA~%wT_>Ea_>3ot!CAaa?bOJL2)2s~3u&6R!p!XXrwC z*yzN-zb5b2pPr)`8(p82rVXf-4iJmyU4FBPRDICp+=MU?+qJZySIGhG``=|<&>oRr zBAzuwCxl2r(r5=Mzp{7ftz`#5MXoH-B7%7dJd5DT3^XEW&xr52|JL1YIAD647-_o= zok5)=eE|N-T>Gt6=J5g3*WWFq!1sMt`RgS7KDE%+8xK-DIy!2zWS*i9Ov?P#p|sT= z?Vj`Qzd~gjUn^2nW*lKCjv1w0Z1+vhUA=P25xVYd;h(KRIq@x+(KqgD!WJOnZpMJ$2)l zfY#4%^-rXLgO#7Uic6%{B}-WTf?mlX(*p!@Jz-FsG+75~N+$2-(R zQx&CbC)OB0%x4is-L4v$*vq+NYiYNhhR*uty9QNz)oY8!XU8RVYQ;};yXJ>K4Z8Mq zb^BKP(wKmg=?NMf1OIq^S^wyta+tWi7Rg<%r82+SOv8!!!9=~z9>xb>jxV1sZZ>gF zZGZ1ww$~`jhJCH8bByz{)Tggr>-9qwx(e;@nMyxfA5?kDvyA!_Wmbin&3V-i<6OS` zPH}?CkWZDPGh$9FH&4Z;UrHY&5kK zEUJ9)x})F17MFT>RIcDk8Kar%^n04Mqd}Zqk#oES+3r6>qnqNt@;R$2c?wayS%1J< zhj-}1JK-Y=0WVp$u09sN`^cyJ(c;k-E~;MfN>?HC)(kZ#KSiemUrx~o$BhLR%RfSk z=XE^-PPUv<8+ZLAqCeInobqk^I}7c&+!||{?X;)PIf>|GjW#k;fRC}Eubn2hW=7I3 z*YlJ}U*!`XClTFKMm!Y4dpxBi{K59#8#+%$%bu6+I9+y7Uv;6nsin;FRN#9$!^HS|>x%iJnx%6{Ei(Q-y z&AB#P#W(rF9(Videv{Z1(6iQI0<1al%VqHWh5h~}Y?OWq*-t}E*us8hnjYvcsRH|B z9Zpj06_}9aUuYH7GVa+`Qj?jJ!%S;?B8NX(O}&ekfo9&4MVI9c=&~|HZ^5xD{wcz; zNfaUHB-a)qOIH{Pxwo!jRzao^Xq>cI=E$HEGiji4M4lB=5nusP z^@!~j!+wTB*PoZKv?Z!>M-D%~^=}{YT%FV0i5(LX(<*ujs7{LOhXm;LJ zmWWh=z(`Mg^E@Q7TgK%sioS%45Ztl(*B)D2)&f4JBFi}j^VR=iEP#MrbeqT|-S+oF z(^u8=!g>X_F-^y~tr7c@BPvxTOuWcD#AmdC?4fe`TLPj%A8Ro>!5tlO<|v0-&R751 zMw~%%a&SyWCnWTn>WS_ILWK>Ju5SFy0#0N&HIfKiM1`ZTZJG62l!2y28ES9h2Z~Qm zT$vMZtuFsZ@Rf=Ory!P4)ltX7u4uV`Xw|=G${a61J_`vQMYEg$FsE5m9|mZNz|9zf zlrI!ty>I$d67@a7Bmj4s1KL@EmXJ9~GQ&aP~d+W;v>vFhl z;Q+%h=wRg!#TV+#G9%5xb>=%A4{122eU_P|S&~fqbiCg|3@CmasRaaN91}$USEj& z-h8O9CcI%p)2hK~c4o3>CMu%jnVkVsU2-h)cw6@C7tz2Z4KOK@HA}k93!uy(=mFT# z_o3)6a2P$gIy1ByiU2Z5N90n#ozbguJJerlXLp23CTAW88x+Ho8eJHL5j-u?xm%A| z_F|m*6H?th!Q5o|00_ZUj~EIQITTD&{7l#g3`(92)G)ZwH(}I?W78ZW_5VDeVIYU< zsKE2D%j$ggvCltU0qnX~UBY@IQGJoB8G{%5+;2c}!OxTVHr7&N^4{0U#-V55nuW7_ zyG42u_Mhz!LW<41E1Ub`a?W46xfeWigi0h6t5D(_1^DIX!L?rpiEB2w5X3Bxz-{=P z#03C$O%J*wA6e}n zbuL@LA&AEJzlJr>*vGM+lBh*wWL|Xjm5w!avrn8{*Vk4=JO!Y3#&43b0rKX!QI}(r zBAqd*sz_8e5Mf42A9y-lgGigST4>Aww4%(%z&Q5rU-F)S*b|@Bg8gLVM?f@~rxG{p z3wR#M;IGzqBZ+3i-%r4TaID*MRA82)B)3Rh#mvB%q3OCp#eO0!!`5|af8#e&m*pxKKfw+UrDsRUuHDuvCXe7c~;b-we26rO|l zEYc~$HmM1dc0mDw>ta?F(#!y1(OoR}^;93n4n~|VkQ(_7L!0bJj{>n^#qhp`Vcbq8 zJmFsRJiubfpcID1SFnWwY=!LV^<4WgEx@7J`@WHl0csy&ihvBt77T$G8dF&xXC%6jf44u=8%gcM4gg4(&gqd4K z7u$3}4{>HA?a5f8V;J*;h?JBe)2*ys5{R2LeiOV6j#ukqTYl4q+@i_VwtaEdm6ccR zwYYM4Z8HA%>~ZTm@A+!$Jw{jO5TC;6QMgKA~Q7mhEljP{Oj z2<=(ekXzs9>N#R|af5Qv@uZmZQOcK8F9-8AZ4rvn;#xd;Fqwksla*L!O*X@Z^~okN z57q5NXpY8mt(N#sh^bWzdei?qSpYE6W5lp%toD}5J9%yz<;m4~ z@0DVFft|Cq+-qg5`GL^uPv<(cw_yJuGID?cHhfg_C#BSWsDj8qM>C&?`C{dakoC%0F3Cp3h}@$REWfyn)W=y$gp{ z+~y%l+8gD!Bla0{9${3e6&%ex#4hn~ujY!z#(`9?)IH_~n!D{X@47!wy{r{bEtacC z^`R(Z3BGfBB*l2FJD>#}j`LWnD zF-Dz3K{lspjxsg0iYl<4rXMfcw09-?Vp+@H*QwsxL!T%v=L`$Q(3~wlu>A2^bK5#b2pRrcyC2Wf?eJQP$A! z=BHf~q}|9TuwK8|$HM!rd*E&Nys!6;{^g}-&=Iv?bc}S5+$ty#)%Tp)Yme1^ttkd$ zlj9l%7H$kH+n?`9xf|)KQ+P5^aL1&!pO?@6%5<}p``7ek)##ddm#a@Yyy)xw{9yh0 zIQ0dagX0m>8IgVI#RoY)PWi0kUh=PCG>d-k<>!CmvgH>4`&IXg+@3eyq^P{lq-Q!< z^i<7jT3+m>jcDWsrth0*AD^b&c27zA@xL8k=R^&&?pKWoM;_?4n$U4x>Y?h_AFBlY%qKftWcn7JzFVufMZ%X0?a%8_wv@A$5z)4i74?syUe~Q zz%EL2yj5T4WR-tJPkoSyZ{l9=v6C~d&m_jje^VuIVqZK>D{AmsJsz}zN6xmpcG6zo z8WXWux%JVT^U!9?oH;(&@ri<)Ygm|tagh6}U)UwPFYHkF$B3ABgpFq`Ts^HUb0RSD zz#eLqqKGh;)|a{hBb`1_u;gu{uycq$|ccX!kiCoq%Ya~F+D0Lu7R?yBCtsx0~+!W zunMZLEYFlID>_bDNS5mf0UQe3WfT#o0CNKTAv-Cyb8LPj9Ta|pm0%|zp_**U4U+G- zqhbYG(^fuRLn?Q$2Y_f@Ylxx$vt}_{S@$I8UCCPi-+4!esKj~29ebT<_Rb5bFj7+d zm3(#KL_lo*-7k_fVaBy$(yi7M`C9|MPX^Gvi{!2n(%_|~ooH&BFrS?bPpabO3X`6_ zVK{WdaCR=7%c8Wn`;vR^nAi8eVaH+51pFBue3G7jBoN49f?A6EVM=0REAA1@2e;1j z6^~&5Ki~LoB@s-eEAS&NvK!(g0lQ-E3pUd(GN?DxJ5}NNhY4gBX1|z@eLGLV1=~#m zU6MH&z#j}jN%;Y5L!dCD7D+rzK@2cffxoNjsZ*w8ta~BHGEba;%tr*oZQ4@1|qgI*2!Aypg>u{Fi1)0@Sz?Xn1YGRQZZ zN=9@K4KOz3&k&=poSq&Z42KXw0JOQIjWwD3(`P#f~3a6e;FwsQu zVxS21{4R|4h!ZFnYU+H@@oy7hSk~VZ1*FzoiiOGS{EnC=u|0`J+J;1tF0_ zaopPq_QJQrZe6`flrxH(1XwoFc8AMlfzO7z3kI?83N5r?+XFH2R~)^GxWXq=CK?oD z*NF2N0}!rK```ocqwBz03R>{ zzGAhZ7(RiyK4NIN4j;yM30w_+{1@UzJ$(1slV{d165d{c!x1oqbR>TOZ6-w|R_^PM z|NbX@g5g5}E(dXt!*vRtUHV^eL2te(HN4`$kNb&&at_9&5%=yX1I7kV=z%&P+;&6~ z-hyE?2Rlpz$UiST7Txro?gwS=mofkdcxxzvZx3D@!(%~GB9HYh;yJ+>LfUy!8ShwR zN9e_}>pn>`HxQ!;ORefJJD;v`Sx4>^3@W1V-kqI`EuRn% zA2i>(D~7jv_NrZ_x?C}M?a#u@zHf7Afr0?nZhi8B1TvFQ@Ec}dI+rV*&n!T}R}esxDZ;qVqH%Bra^JwKnv zf)CLQz#yd8h7Z2K*h>@?d{ifUy1SX^0|fW)XZZ)+`R#)rGXJ4H-D?(w-mDj;FK76P z#;@aqwsdYujhK0(fxJ^zCW*yOVkfasO;F|*T^o#t&6?)_cuA9NqQ=%%VnqPrWAy*X zb)Nn4+CcVi{1{fGE5`REQUnhlJ_pIW*`y2bL&?-MU9#R-m=e32_7g{_!4^}^u~u1u)?kspFqnD+2b^n3`&umDOy?jp7; z(!#(H08Q+GZjUn@hR8;PM+-}vG_~+Jk=#XJdlr_;MCl6FLtvUe40_4)0mm`0WJb|F zFwdUH0|$2~;=+bL@fSe(VmySUdaU^Njj z|2mk}?L{*S#yZil1ewlT{QCP36F@Q#z~8_p2T&0dytdm^cD!|hU|;2q7Tl3EU zWrnGzB#D#aiw?r0@+17#iXWmF@kIjNVP_>w5x7B;1R(f0kUUV_B_uMQ!OW)LuI4Ss zI=+a&gn0r6WSFACT8_;0;Nw!NCxc|1xv$Y$V&FW3ZwK4|uNZK?nEW*M?q>&rR^cTN zuZ7Ftr=2(d27}6wD%+DMVc_8d2l2L)V_FckCxmJx{p)?GM6F+Bser+V<5(Mk%U;*l z!y$n7jv}6GAbF>7UzU+O)H&@2paJH}wNF2xXu+<3m1NcMWSOpt8 z5U6q#AZs(ph(zwFaJDA%BkZt`A3a*yedd^*80GncHPstJkN!$&zLaCBKh7BRV^Mo0 zpiY@5*qB1#5oMe5|3}kzfOFaQ{Uc`=1z$R-VvSyYP1%Fa$wDP_xuvdXNivNsJ4 zBO@gvA{s(MS|Zv0pYwkH?|U5Y@jl1%-goi)UDtVjzn@wDS6*Oe*Oxoll2^>>^)?wX zO8Lr0pAV#a+gIH+TNt)oJYSM)M7Y?~$ge-Vo?n1lIEnT*Uo7jgD3yF#z!SHy&$sDb zmN2wXueZ|e6xI$vOa-&|GKd2+5$#JA;*}QKWt_xrWuf4DL0se69W$(z&r#;1S zX-P|5#WeBkQDwH-&y#Xc&&%1jb+bNiU^Bk7#vtV6nsCav&UrPKs>Ou+ zz!st4Q)yx$Z!dE6@7_4lY8mgrT+)+U^=`v?U@GnPZ~hDzD|t!?6e zCA0pBKf7h8?^$@|f#!r^lJ{};Gh1W~t`$FhT&K<+=PsO(w^MG+<(KUW1K;XyyWB7L z_U3o$Npf`ssj>gQ!#KRA5Rw$G!u8{`29Kz>V2ns>kc9Pta;Z$Ne_}jurHbQq0Tc&)(K z;>dkc1HuRHfx^G1Umw|imQOd{%EFS1)zPEGh~K?sr1C)H;HSu%t5Gel_m%{v4H{jG zSBiSVNqcFjQNO}=%C_-yoYK*yiyTAUa(|^RTMY38(KT&)z@qprY|SrBKNLmw^!uK~ zI8umE1ozl|Goz-wrg_{3bB#*nuq~|p|Ilp4`&>E1jYy0NL4jph=Y{0 zG`J2u+g&LQSu8C95{V!*NPjL6{5ttZaRm`HkL0Kbv$C}Met#E_u#hOpJ*K7vvmqJW zIGzxY_Zzb!8?I)Y|C8t<$C$yHUI+1HIc7QFZv4XBvyr+vIAs11wJ*qcOd+Iz4dycv zwmGt`loRkgxnp3{Ft)UOg@9JrPgdI`1t6CJ7)k>8vC?3)sKo4tU53K5p~y@hm|rjv zmjYWxYsGTnO3tgJNLsH1eu=P0bd_(e{dVH=Uk=FJTHwL6-0gdxK2`yQ*dpT_4*^O!gGn+2vRSp<*-Wy#b<~N(Nrk9$D??CC6}SGyqII zfaxe)n-n0E#MVb>=)yQ9yl2lg%vktK5&y`t_`9N_f{b*eSPnx5F2Jz7vF#3* zgo@%>@K$R_p?3I(eYoNhgh@|C%XPj@E(T-xXSM|{pT&etSotuv)(?l5e;JCB*7Qq~(xMYPvk1`n+7n$eZ*i5!JmIUWQ8c6VAFSSkWymgS| zFp&w%!Xhm$10)O|`DS4qz-x-#PTM{kFm#Q3e{}8*q*O5M($F{|>`#79OoQbXB_L3c zL%=1ao5KPP^7CPSS=mzV(Mr>zV?diJEwC2H^S|9KPU@g>&$|p%ifAhCD1%lG3grXX z^qi}>Q2>s=zxnK6d07Kie7{<(`pMo2PRHeHQI(*5I*i1?VMrw8B|ZW?L-mddND}JA zgxYk*W*d}Z%=&g!s|w+D*y>ljW6 zOpL*rMDrrd6~|EY!m*R#Sk4gdjk$O8=e%dn;0Ol(a2o^Ou&??KXoJbCBV$p`OyP<5 zAFnqOO0l`T!UsSJx@(>+r>Z~y=r1Nb7!Ocw_P3w!PY)=LT!;fezfFjpK_m8c!9 zNXx=B{=w@Xht95x8l>L|DnomfbfSPl^dNO+a9i@R<*BzH7yY)&%F+git4QK&h3*cA zB)-Lz21pAlCALg9sYVj1p4 z;8lzTf2HJ50B+&u(|C# zY^V|ck>+>7s}XDU?dgX1_-;MNnE>BU`frX)_$Mfu<`*g)aOvP)kM|*`-)KNCs%5@c*N|EHX?+-|&Jy|&LhP zvXMI8cd{9qn%LGpC!^n^7S?| zrr@VtktONerj$d|a>1AF9MfYO1arbT*vI-01RE%b@8S4YOEG-an5s4XOU&x6>U_#4 zkG1^HhyCVbV40Rk@rhX*y)44be@R}eS>r=c80+BxxsS(!-OA7PD?MN@bMA^QXp4I$ z7|nV_-uY#dVRhF->7Z=4?^oYFF5z-;d_B8h&n{B<$35k1!tuuQ>e|1{?3lLxmeSxl zTExFzxLqbTpQXgd=5NjggNMQ!Dl8H{JW$=x#*^M;o8bM?wRJ}qmF|y)*lAkfs@%5k zO1I4Z2FVP?eCaY*@W}|fTf(wq5E7On_2g^s=-ZX}^<&Pu9glP`S?+ti{hrGnhJ?L> z-s)jTas-J_4Uy1kNx3KQ z!+t)|sw4#=wCJPL}}g;RNv`8=6)-v~sNkW{r|u^f~_7W*T@Igu7Hn$-a$Q z+~`(rF!&-@q`oZrsZPTO-NaW3KUhEA5?GE9tg~6&|K{C}ZVQ^D2e{8Xqa6Qn?8($V zaT^+|Ao^V=*8l#QbU3!~;Z2Ji)_+^J(5vp0YucM@BVxt3SEVgQ!ExkGG1rJl;m}`} zzfwjCMv}&+drd`RPAkynh3)p}d$;cqEqh(`NY4{JVK&>|j;ni{KePnpM5MNuQs#cY z=WJ4+cB*l(FUTYL_njCQ&3%d0N$w%B@7$FR%N@C!*{dgS?j*gSi_cCi^XHualjHvB zCZ>-y=;cJ)UT{V|99D^$WNi^PNI%MB>KL&>G`3{DQG1yO_W>!PZd#hoPy9D83H4pm zZI@BHZ&of?Sg|9`kv6t~w!OqL#=u#)3eT2t$l#;bCBlX>#a!PNHjgXuHyh~N$qDOo zoncb5s#G~BXyI?5OYcw{ruZ>L{*&g||D40$@_E z-Pts{Iq?`xt5dfvjBVk)^@k~PXkzeVr8Fab=z+U#zQ;M9`YateM00sKJvmwL;K6I} z!U*iPf!07@KNH8~2dmd%KEMuNbNb`F!=>2`0RD8@#lPvd4WC*}R30VHK?~rf^G)sqmH2M*@<-zC)ygJQN4nj+pfj}lj?t6!x?)hwq!21OCY73+nyp~CUKfAP!>?)X zq!KGrqy=EGHgGun!k3_K9=R|SxOVyfi@Lnez8)z|E^ax_{=`ogkEp1snhA$4%}$19i$lOp5y@1F|Dd9R`9_JPvs+nZ zxLD!}#7Pbbu?X_mA<^0|Og|gej!OfO`rT*GZ2S@@+PTq_4TN9DH%*35^3Q}KgV5h^kk1mt03I{!l z8aEw`{$&1$*Z)%ZE4(9vSm6lv2FvObEG6JWZXqN-zKJAXg3;J)x{SR6;2T7_vXYr* zXvNafqQYYhfJ@zXBIu)+7>W=8y8#yiQU_iIxngpuhP!6F_F;NPn&&G>@B(g@KF$^= zu0m{{k@Cq&Ny62O*s8F<;g2Sz2%VjIu=`_Ej=yn32U-t2J!25Qi|*LLjfgf}dqCuH zLB7HN7G3kyOi5f?3BbLMemem2h}^mQYRn!W{MH>l)J1Oxi!WOLI)KW|(!xS5I2Zyb z8YreOKE+~5+RWnfi=i-?pB^Qe2~=uN2YefPifA~1`HEk!w>h%1Iwl7^mXefIJM(n| zetGy^v1fU$tJVrmjk*;qGZ&1AT-68Na zVHws@*0d#H*q|litNWc6Yywx#PRke?8F5tZ$2*H!6-@IFA` z&m9Vn*}$IuON(3YMo7QX+OMCi*gM8>w(kqTbgrG}q|ZIRg+-jM&;OZ_4P_v|2i0UCFY3K%th)oIGwAkW51Q~w@Oa!G2!b~Q63wli7Z{(Mfx(Z`w z5~@KN+YBxT0A~jt&k*+=^M8rGv8htd?}69~YA=)T0)PWr08xc@ zyEIMul0tL^uMcV-DqA*hx5=%o$=0bVZeMx~&w$rPdR2j7(`|%c%;!6}&sW>E?nwb%+Kix0hTbG4iDUe1E zF|l#I*4ovN&rEKc46#iG_BH4ozX>2>6n_D!V?h9(^OQT*1Q4L%N`*ELbCINDhieZc z7FSO93vBVkD??DVA%@GB-g~H2=zhR@A$I6C_5AUr97wkQ>-no&y$JFe9MD_G(zPG) zYuCdHHFL8hazum_;`Tq?-_Fd~*P$CHOm%Xf_A8~snhN{%<&EC>7sBo!`|KQ>XbDkJi0e+xNzDOz$#pfRgYJo9u1jE!IW~svF~;XpE|HJc$U^ za#%ZQ#&_cN(J!sp1&4aYwRt#VDNY+`6nJsGr&FBUBibt-CaX<%$5i{x$3P|R@*-~k z%|i4K2FlN=W>d2}Mua^r;SN385-46w$y{f}p~XnwTfSd7U*+*hiW;7b2j=DH8tkVDU$q_3q0( zqtnz$v||kXdX@YAHc)b-6%ZP#WroIc4>2>Ddvqa->H_y z#V#3qsz!3z)|IPG+LhuYdz^{%wC)#Q?wU3+ZBwQ4-XvL}_x#IGLwF9?oW9)LTr;)f zVyfIpZYBEq*L#_$musTD=p#e^IWoI%-g!k^|HMD zTevRFJ%)G3!5QlU%So}{XSp2i2S1!DZrY&a@XM{`lUR)ZW#%@~;rE+G{nVnR_cro< zZ<V*0qjR{Kg~as34i(LJ zMl?_uKM<{ReyQk^;4g19wC=LaO-&#D`8l`S+TA=HXQb9m_glz2x8ZTEs8(T-zPc=z zTqK!bl*oPN6dkjAo~>b`_r^;5DX-q#my5^gc>nEf9OGTBn?A75R9R>i6APe$Rsx*B z?*%oAB8ChibcF{AG-)RSMo4r9)I?BKKuOSqzzLY=c9W?V^5Oe(^)cFlz(mkw6oq7x zI~sRz9AyCP1Iyom6J96bc$p8iXT@z?uCw}o3#D{{GbO2@m zl5^wM4jJv(s?({30ioDQG$@gnsJ3FAzfvCfwXOaVFfQoX| ztUY(G>PGJ$s&R>K9?eJ%V?M^H-tqCNj1%Zi9LEvv1e*n}6l;(B?aqp_$p=ilN^PE| zK@CRKVp3&K8o^gf%|7z~uYQ?Sb^*8sAPL5wHPZ;zIPH+J0S1i=@z z%*sSJQd5y@jwC!_G$z~x$xFcPKKZt2GZ96TetG1?oJ85dB|zJP0{?|S0Q@A>H!z-d zTqXQG+VP%{ntrM0Ip})LS!^;cBkemJ|ciNalb*NvjY=Fy6P!TA9@T-cvjcN zrKYywXODb7Yzv4G(iqw<)+nI6AV2J!cLE~E=o{K#YYQfxtU=_9L5Rf+?4|!HkkHQ! zX+{fVbAjx$8zPVG81H8_7-wc7~r!6=Rc&NZ% z+laU#ZH_epDa&yQ%gSD}d3w=KA@v;!&j1u}=GQYcjFB@7#1~JN{9{F?<$?G(_d$nn z>d7=M?CGEhhzS?oR7iG+c+pi;-A5{({KvdO<+bE8=gTRLQS zwn{u;0U@R8!LvaMweT+*v(r5OD4X!(;*d2|QYdGL6x^`eS4C)<;?jg}tJed|pc1UA z6#CHeakTII81dSzhI~a0mnfKuA7kxy!H0!~0{Rm0Tx$Mv#c_GQYFkO}hVN8gEKKuq z7{HuApGPl(Uk9;@LYIQ3_5sk0V8o~F^z&WM>4=UvoQDNcV?;Ng!nM8gyfmTR?H+eCz*j5DC zfbs=m0Cw{?Hk31qv#w~3houAM<17r}<`2EID0j)`$FO^lrZQ+t>OKY!l>yaYnH$PSA)VLb@TjqeC6 z0l|Z>)mzOF1Ocf!Qf!G5HXx+pza{i>kf0F4#d;mK>oocsu1zV>wTPpP}FYTx$t!T-A&ueZ$(RY@b(G$k$v{vGRG`2k4kM8JXw z0X)EHke`Sw)A;%2ahTzf5Q=V?Jw@Doz{?$hba=eKAB$^(82jJg&RkupTT3TRkppM9 zU&kFm=rtT9NK^b@nqGSB!mCl7%B0R5H#4KDDRJ~*k@vj+vx}RzHtHg{elS^?czR(3 zrw)Rimbm$c*}kBxQ}ybz``5JvvA4p?rwN`93li2j6FJ{?!4zmSAo^8gzoYq)NRZ5h zE+OV|dFsU42wWyC;mV+q2yKs@8sN+V>hb0w1*kBW#`<%`ACkA|lXFWeL_LXp*S5m^ zA-vl}?u=J#1NP$+ie*A&{jadqg;CoHv%1cCO?a%3y9w9}5D~JpI@oOAipx(;#`1pC z_zXRFB`o1*z|HxfGU?N$86!k6fFK422*ojwi#s)C-L0R{B`z400v&D_D37&33g*xI zb1cfDmEY@+SfBf%7N!3v#NS=BsOzRuy;h}5@pb*2Xe+72`Qmlswu+VeTI#mPuu|)t zDAD3hr@k%v`^y{QJce#Y$+p09hwXAgslt~YYt27n5O7-I|7Y)Yvfd_>J8@}F?>8S~ zj&A);_r_1p!+0nR{`$aOEA+d6MJe>uOWw=berIMFwNPGG_gX@(6_5AMH zdtn{g@(VG-b=;v(H5J=_|9bmGw)mEEU#MdOzpp|1UH#t83v%y%q$q3%i_TFhW^&LG9PZCn+uwe*&M<*GQZ zAhc8D$@_it|H6jzc1Nsv>Q{WW*<7FC)YRP>Bf%RZA=>LDn%yrECN?nbwQ+d$!*$jo zw;+4%=E0T_i36fXO1jGBgO++Ijz|L+kIFES8|NyombdqE2-;3j)It$8pRF%IFOn@E zE-dADKP>mQnb z&UDJ6E=;*KOx{3n_Oi)L?*QwYpT9h)9rfEH8YEd5A#*X^zw^@SUT zN5tRGJPmu9t$*ydC#}ITgNMz*->Lk?7=w@68D+C*CPen>7MvJuUAe$4v9Wg_^Sq>1 z-?HBq$8slbSMTNh+;wZQe6fH&euEmWy!h7*^Dk<^w%`h0m>CDOqYY9Oyso|*U_D^6 zB8B6!HcwuFer%#LD5L0AbtY&|b#)TJDdXX9(|&4xzX?7_rPjTpr;gxU{sF2pj?RKXB;fiGsrYfVe=;3IkOI z3=gme(P3@xRW0rKj`Y~yX;LSZ8XNVXX{+mraGD&+>>|~_w*RN9z>l23ZoqH z!9+zJlnJmJn?jiDGWMOxzLG8Y(Kz^Bpjn5AX1s27sxGZf+{)|A!hz>gf8T&51@EvI znp_ajtetrgvhIfgBw|~DQkDHiHkm{*SHJslZd!I?%Zc5T#Z~gNOS|#wbncoV#lP(I z;nQyK~mxQa=keo|uUN z0W!I4sK7VgldZ!@yJ=fHhTdSlFa5LkdtP4xJ8U$(^oPjbafh7w)+-?}kdrRp4_n`# zKNzRAbUm^0nWbcg9xG8pyap^H!?t&2{!a@)l63LSJ44NiD74DJwE$2$WD!m`n6uhb zdbXtvI4T!xHeh%Ng^b~Zo-o(&b9-VnI zH|Tx|HMb*{D3iZ{nG3=eez>-S5KxGjV(^Yga-Mkf@A^kdMCO4}UNT`PMmJo}6Qmg=dCd{41k*G^-)#th$m_rPZG zI0Qrnse?>N-ITLa{crY8f|ZP;XA>or@|Z#%LHtJ`3QUvs*3d@EW8ImAm4%RnAOc5h z#0*pFAx)9zA&bFj6o^v*>v(}R<`8%*!f?9x6|+Pse$2EitK2Rl(`-+|$w`(Hwg0~L zqhMKJ6st2oW2JuQv6U=26kre!&xqal^VhqIP{2QU8@Axr)Hg5~#%oJek8U_Ys21&s zS&uXugxi01-99=8_$;ctqjO2RkoP-B7p{ZN(i zjXi70cs@`J{@=Pgk>b&OkUqx*DN#*TOM^Rq&}f8^|I& z4fS{>P(eh+{-nBgW&BFVr`~-hrFVnXA;Tol3W43+n>XvhFyXfzC5+>pcTQqGCrM{8 z_CbJ7jP(TOhwiO77RNRYBevpDM1qAvF$RTe3l3eddE^zu?}9hZ05=FcR`@D=AW?xw zgJ?8KraBhI42{6sc$!I;1*kzy#BXTd!DC;AQ_NWSFN8H$$Im>nB8pf9#^VJV|ae@4)NXg|G5KaazvY>Jq{$|Yr#oDe>3>4+ZePyM0bcynxxNRWuY*#F99iz zFjOWaX=BI1TN5I`2VJLEHgF3*6NB#MF7$_Fm4s2D0>=ed;8KPlYAVXX`<}0`(&N`6 z0MWmpz&o?Evqf_?&&GKGpy3O`-?F;cUuPpwQ&dIC#>)B-b(iEkQb}iucpeOx1!e(~ z*MtK`K4Rsa8m9Qs$-Fzx=*Yv%4fB!#eld#FnVA_9MFw-+NG+%~Ehm0?d1LltTD+xh=+FbNkED?g)F?QoGz3Xe#G#UeMD?1Q z2Hi`(czWf-d9!3cvV&_k(7H^OKaxjd^S(jFios1zlZ7`vy6tWna%?@D&eTTLFlt=*$t!Gw zft<8sm3F%(<2_DV!J5NMQDu+jiexw4Q!!W*ev#V|)7*FWQNge4S2n~;-ngpBGkterkMtHUt}z|ZQ3zV|PP*Eu%u26QCihO|g@VL&b!Fxkc_TKw^-l#e z|FUUPOD~8f{PeNy=#{L^*+yeZ*~V|@t#AH^N%)}8;lDQ{^q>BjatfCXJa#9#CMPOM zqshfEN+GX^N!*t6OZOQ>1j90&TAC0 zuh#K+@w@*!zvHm*@p-!^mi(hP{Oqr!7>6*duGq=0K8t_1R%tIKC&*peVPpK1`Alqh z(~!EH^9$-j+WD4WgscFaue{XOVUH}hWYDJ?BFg;_H*stvvEKqes<^14lPEgs0>{dzRu~&xg~f8&lo3t-t=jo#x7=k zuSjm&mcDNNhlW|F{h4)+tmZT^X0{nBOD1Xb1v$jUhWSf#^*B0Sjgc4#k}z%l)P1ns zR`JG~i?9jrTj73PE@M+8*8QfY-yIAa*L0=}v_B3~D7Wy)J#=LMyZ-l)*~h}Jzn;IL z$Qq$bn0dM9hJC|Q8FPmKr}Pfm;*Umt_rnkKaD3-|osz($En7H$ExaL1pr>%X^l{zZ z{QbFmyMw~E7#dx4kSsLNnJ(0qustZ3MH@2cBogquBEEW!wl1oxe3#}sUY+Qb`1{A| zGb9`)?KP(h<`0HN-8PN9ZAyPD*7l{+@OQ7f>R(nG@(Ly`+?c8IvhkV5ojrGy$8S^k z={d9VrZ!00@@Bnryqrc9iyUGSj4u?f+4PG(%@|>%;~)L=u-an$jr-L2YQ;-~6;P^y55WII1HnCvoS9RulCjUJ2&A1FA z_yb1j{KA)5{>7!GzaZj;5m!D-+wh5XHdX;8C)h4JjotNOg?GU$X1BG}MT``TtG$Ou?75y!^E|9I)FkoA1lH>*P`w~=du$A{tL62Nx+9Ac7{7}H8! zSI>YxiTaX&4AhTa(Ef|HsJOI^AQqI|n9cTCn4o}x8O`5?!E?X<<;j0TU0o!~CXkf^ z58o-&?&kU3!@Wyl=0QMXP9&luaANo`j^o3V6U5}9LgK5yP3OWs*|=4G>qgZX-tLAv z*aiq+wRuB5f$*?4Y9Bemii=Z&0I{%$n!~!rq(%hc0BN0T#_j8ZqycwveNQTci1|If*|42f6 z;5+}BiNbGZt}j*nOFeY(>ki>yW=F|PLsqJ6ItC8v@hrm0bIpALx`29QI4ukrvPfMT9mn;?EXx^`O z^4HL=T|?sw0lxraVj?A`P@;bd;<^J(cgVI^x&81rI%4X-MG*&EE1A1UCKITBXg;h? zH*-^KNO@em<%Zo-plyjXs;Oec|`4W`_VQ$bO_}Ud3hPF z=oO;rzRK1=lxs`j`R67Jv4HLx`cRv+52jlZ9+(5fd$VM>gM`ehtxdz9iBKFki-;lEx_`%oKe^1Bs!J%GKl{k* zJdNAnk8=PGn;_WOSdW7V<-aN*?uyxlz0GG!?ztii2MM;|Q&?F>|Ern{`R@sapW}GB zEoIhW{m9J^*59iusu^NL%G4{Yf-f~x7da|f$fsV3-R*bKRr#y!3Rp_IUxQvOp=rihZH8m zG7jdw8;3d`pN<)pO7>K9lBWZ=i^zY#IyD9KeNEt4{(BfyPuV1=v1l|IK23*eUab6t%9!@>(XDYXaThYd;3n8fdzEH zjICAIko%yppnzq7^6MZQAK&#Jlsx*5O#5kjv+>Wl7rm9NmLs`%1d~wp$ctHMHp!+{2%Uv}~OvUa>D$2K!J&$83`?4eO`h3@j7 zv$Se=7c_5RyK z@8>V%@IU+aTz}!@WOsF({540$+>e|Q8FUso2S;o9nK~CMzPvDJ-B743lbNHZu6ORh zxhBQ6$A&%!<{wCfsp!T#(p&Oz_ls}2Ex3N}Cnsy*H?18dEX%r++OLEkiR^nkoUtjw zCFe@1e&(cb>^UK^(80ptp}U_izrCJTYByaSwP)$9L-F}0hDNUAmOUTd`8a03BTofB5^UW>eee3A8^2jwRfyE%3j`v74e#+rh{DjMJOE)_q}Bm$O2@ z$cq*=dv*D{S{qLqT={y?vCfU|#d%X+Zb4JY;UgUWw`uflbNKTy4)gL{W0Za|%^Lii zA>N0hE#N+dPexWE{tgkT?G=f0SJ_WZ1ekyPvqSNV@BlNX*_P_kU`5`@Apfv+1EN`a zz8ULE%(wXnwXxpZKyOXspA{q1u;1jUL-LxqOG(<^Sn*qz^M0oH6qNN}a0mStUyHo-jz7_SYi=r6Z6_Rr>4eKyJ02J31HCpT@H%p z3-i5456OcAM~(%szfq0VQ4IjX>qzzo&=I0T!9b=6s|&QDh5#aPdrWL4S_7PO1h_;< z8>tC;glP*pcfp;47@6OinoT1Qm}D3r^atpX9k2xw3Uc+ta>)OW@Fo&LRGPhnwetM1 z^kUXMaRw6E1{^5lF)sMg@ryNkQFA$El%GQvjmi#i~qvvx7# zLIW;F4&MtzM!zCoheD2Z{1n00e1n|7hPyg{eBUW>)c4)6_h)Z4JW!N6*WZr5&HlX^m4mD zbhmgkoFGd8$Uz^Qg5*eGGJfDCiJX)SsRTVFY82cUWT+ysm>76axdBoW39D10Glskcg`8 zD99atgWDh}P7qx`2gI+Vo8n5WUgYBv7DibT zHZP*xSRUIRXa`grRQPF#b($~bA@%|dEk|u^K8ZZ^9DK0DGCeIV1*IhYR%e(HH3S6i z>NDm3KXbtq@HP;8F-d#@;f+B6g!dTme-bSZ>AgGz4MYfs|E2@W7#5rt=U+-rnR_MP zYjLicMki)86f0zq0^Qt9Y66g`C@k(}xL7`S@GDHLWUC?@4e;eh2+buyhIog_{Mi`K zcqJeOYYGNRJR9Yn?^Un}Ah;DGmxT0mscL%JXY=-mXK8@AY5wD{06pB<#4L!|9O6r% zAYqTo-%o0cf>CpZ@Ydc-48n^g)HjieFTBrHxZV-9Ns_(5aD72|lf&#p_XlV}fP5zt zak8NQWJFr(3M5J{WXOj5Bfq%#3o=(mAxaRux}l?^0s9$pN%Y}x05d{@3(4XJt;F%e zoHhM=`uq2@^013?InTjA9nJh*MY$SpUZ&H`!X=!xt!nR$GIJe{ofKxxLZw3pgWB%h z))ZpALW`7HdNe0HSzx0jL;@!J|I#!8ibtMLKS^iZk4X0a3-JY|)!NQf74W0~3BpZd zRX@EzHnOXy=NOJQJK*xr`$K?rN4e_9PB*s{&(mO5$iNSAA8E9~>lQtkbItko4FsR* z>FQpg2=3%ycCN7?L9uwQ5Va3M`6H|WWM)dX@}b8^MM7RNzF9kZr=Bf_{t?!jaY#*w z&FiaG1mFpO@D<~=w`U?LrMC)f@Ca7q62EPc2sUcH3^x@>&N6e<|A|gXEmOBKA zV(F=N>u}O=yq0p@m{hNkgAeooT>mx~Cq?$G%pNrsJM$rS+io7vANY@v*fc*NlnO+uu%aT{+K!;42yc@~ z)(JbIgcK$IgJKqWJ@)f8{&ocF9UfULP)@o``{lRMkYUi_jTSVw?o(02})keds(IKMhOO7WBhOOSqh`y%y&|hysvvs zV@4&DpQEG`OEMUa?`5{yh~Ac2#o(tb z-%kA7r*d*hf}zigbL(yy42?H8h4`&a=ymqS`-p_Jnl@h7iRZidJLbXYllo|xPOW^U z#a~f$dJRXICJ*_0|E74cd_Ete(9-7C#`-_Q*YCQF|Ve5Uy^it^aOZLY~&x) zzoMx8$@`ug8jrG2d0b< zAz|doy8?Gl_w@Taiq;GGSSYrfidenrxUWd-=wn(o5lyBE?R}De9;s}j*AMCIeBN}z zt@WmmgH&Ail;^uMB@+U-K1YeWg>oL??W3i0FJUe*rjfDD`(s+9Kb}3=E0Qh4W1+9A z8C@E^`G)s~O2L7>=k$eR&fTW5ly0Fdvg9Au6=QP0)Ws4u^>EIBtNz_W{JA+I#QmHfCsx^P?4=$ zCW6=k+}==%j^CwHG%{v~%uOZTNPFO1HR?@p4Z}(fmi;KgY?0$jf}~080PIZbk06(} zu<+cs^+r9IUvOMlS$){iQ*BhcDba_dHIS+~{D7Ji8}I3#^7l-=_5%{M4$C~g=k4k` zf-Gnh7yG7o<^v2hEI}l*79b!vI30{cm`J6q9-?8D+)Si_5i=!`Z{l`k`6z9`N+tsg zW+Xrk*DCNaSj=qY#J7F~D=f(11=SY; zX>WIaP@VAkru#=1~%Q zX1~-=?97cCY9))c@AIsDHfp5w|xJ}4o*Yt***?7@f9N3@dclL?u2 z4-PfPMM!WD(?xFPEeth$G^^<9OS~WjWiNL;3EP<#VK6rZ<@g_h7o zwKEUtAhvP{nWIEM)U6gtCZ0XyX)G>Y#QTQ|KM>_&EtCx2 zxU+=}SOF3_5p)W}2{JGskVc$%B*ievYXf$;031d=N;NS3k(Dfq31kUGDS)IrBx#bM zp1q3BZ)Qmto5JlqmwAkO;1?A^?vANxIK?bhG;^p<0hn|2o)1zx)7RIBA*PJ3wXA@p z2BuwtgM;yF@~cC5_XPAWFw>F80diT)@ZrJ+BrXpDa|w9VGVmNJK)?UVZ-qoX5T`9*Sofv*8J_@>BgJ1~2qgeBg)VeiQAwInJRw9>`7?`>E~N&5oHncDer z0Ot-?5JcEjx%bP)^HS-`wc{*#h?Y^1G~azZC)@&6gXj2qpvUG$cs>EvF!ER5zyAk; zCw~z0(A-;qk&WbZcz@Q!(T|7I#fpY1nCnMoQE6>7{Ac(dFcd>l;?g{^#JL(02tZ)H zx5}$udc{kDbp`wfSlbw_nobN;b6Y$R_d79{bCPY;R0lAJ zScP`HxI*K8h+FCKI-c51!?``7y1qUyWlq0il*uD;i0Z*%Imv)+~{i z^jZ`4sWa-k@JqxkA5f(qy`_JS?h=-8 z_VpcuJcWd&cAvFYlkPC%JF4V|9_OIp$i=d?#j+V){T1dOrC4xo(-54KtsOJh01P2j z&aX#cg|Su_H!4D07Dg~qUufNb0fEVU z7ytw<4iOR;ta9p;A?zcVnKmHk2#Yn=spjxJhmXPrtQ}HScWP@h`g(imc&?B{>l15> z{X|lP-wq~OP?N9#`7<#$6&R9%7J8#z+Q;$oovry8VmG*PyqKc z$j=oQOG|Sy!t@cu=iGZ1GL(7Hxr8Rg42L=5kd;zke)w}c#D=#)e*-rK=A8uS=4bt# zbW`1%&P(Vmons2u8h<*GYMerZmE_3ApvUZK7jiG{9x@}K-v@cSH)cBYw0n=#pJong{C@QygBhVez zo;?Y|Bsvd>2&6`8H!*iW@&v2{0Zl-| z(9|8umFEQx2F|`^Lh}{TmzCc)NjqR7M&Q*o z?lRsxsB?xYzz(H}0Hsh&cW&<%c82hLpuYephSB>H$-SA_8$i;*B?Zk}GwCk^oOY0( zITOPWY)0|R%gYCox_>NCIMB51GBIxz$6ZJrR-Lx zp?+U$6lZuZj3xKr^2WnJ?pYK`(L`sH3}SS3YIq}^FlrU9HY#?ieG#7U?#M6YT{x(Y z_+VHhMO0P)t=}PQExQ>3+n5aE;^Tw+r^X1>f@nat1_oT(BTwwyav`?q?0waEWQgeu)Urwr) zAn&Bbm4K`)$I`P{fH1R0;Cv$nti4+nF8_T$XvxpaY$1Xv5kj(X{Hi3h8iV({Nsr@} zmIrWvz_>=dYDBhx4;>Cf1FW7TG|MD_^tMY#Ae2mM?g1V*Op{?Q_v!C{9UH@>ikLta zs&3~@eY|k4uC545Kjg@O*B&8ZVWk=LFRdZ7gw7X+DRY!x!$P|YU7DCfN*{gYAr>Ow zR6^grc&oX%y0Xmr;25yQe?vqq|BH%i#$?T}^_H_cLE+FtBdT(_<+@&_)>z1i!`R$ei+|cPGu!4)>Zn9(VVg1S%8)fFnqN|s zf@dYwM1(@D-Tae1o@n^0g}6tuL~SiF?e^5YP!O8C!^?4Jz6^6|Otyq)d=pD(&K}Q; zTRZj~6ReRmI(5ZmMLH%);#RD-Dox!j_4uQb!=c7P3g*!=4`lRo55~v1jQ2>!^St(+ zZ#K{{T(700%|9hzpB;bp^?drIpNiT{tk1;P`Au8uu0DISVg1?hi%}&@pRqp9&#vo-M1S$>Kg2nDLMAgyC&ov8dpX;psN3OW1J7%2Mvb3Z4bLTR8fS~eABh3tx`Y*Lv;M0Qd{Hf2-EE)=26jL23*H&AIId9y}exykIh z_a+iKJ~7bj`13mPw=by{IHV&-c9!F zM@+YdvqxxYT#8YRcXj@DiKo*sbCe?1?iSmjsZ1;Q6M;7{D}B!?HU}T9m0#h1RmadrRs-*}CT98@xe`Uj!Mm6l(?}RX7IT&llXC zv~o4(__ZP`NuR0YtEH$-=l)ditL(grq_2T$^v}xZovhMaDPx~`nBV2g^-IvP58cTQTxxI z+t%H6aC_x04S@}NYgy(d&)+pj9950h_;KH7#6kYFfScy7LS3mPb@t@!h7OS{cfx~D z-(`r-9_UUD?dk|}jSsR+U7smEkS~;QshQ(|n!{ZOrS%VZ3?c(XYvg{PZd+HGDBZ5W zIlqf(uYa{3!;X-E6?t8XES}=F$RdHL@7Lb$Td0uw7;s2UuE3wCz6I#}qaLj{C=PZX`Zv_4Nr1d(3y^6Z~&AV&;dP!E1Q@3oy{9!69=yh zRWmkb0pD%9m>T%>DH$?bgt@pXXU|>%%vHKD_yDSo1lD^I!XIH`LgwhYz=Nc8LxF{k zbW^jX_V1;FDL)kfIbbNB01rq;5KaSPJcayQJ@K1GM0|RBEZ?}Avi7xW8U1yWk~U`{ z>W9k*31G*`@q!{1viNZLGpo;SoalRNB1x%sK`^e-uXT*=oiR!PI5f6$FRde1a&U_u zCT_HT`Jw=CIer!5iUNqhW3^M$SK()CiOM`dCGjZ~!v5JT{1H-2F2{ zDB#J+kG^ils3V6C`3Lem{Vd4M2m#G9)Z&8h--6fZ1w}U%;cKg8e>GY#(faM`ziYD7 z?VAk`bf>JWZ(Dv^|NW=K@#uy>sBdo1Pku%-59|i?3S=#~iMx>tM4TRgX-Eu!(`e|& zV4^&S&d&Gd6<}%jnKaL}RbYX6(8j9=^QpC>3}<8FoN;hJF5aQ1_Ruj2L$kmpFMk0CYDo}g($F9A@m6>UIRiUTTJa*>d~ov}~$D!NK21CcE?h#Kh%fG{kcJm{5= zFAhKSe4zgFSWujRDp`#f{Xv2OVwwOOzn{4Apb@eJv_|{xn!x7=k33Xn_ZuITC0foaGhz0YStQfSY7c1o-~F7%_`h<1un5+1ByjxEq};SIV=(4At#DT$Z^q!Lw1b@i{wn;PI)63noThQ1!`=5 zKGHJ?cLG|O%p8Jv8de^N9@+K3pt0L|{P;7p{3Ivt{P``SqM}f)FhcVQ`zaYNj^>3j zQ*3@<&XFIZ3^7>pZiVUYw6|3QB%cy*0$`Qkm&m8ye1~5!Yp6}vMD4A|3M_SUkB8k{ zo<>G)fibEI&Hp*Jd)OQ9l1RpfN=Ywv%^}$%5|Nq24M=V-1yl5;dTM(E0|Tq8S6L`H z5Ct&+bUy*h6bJ<%?rPN0Bv*2%;R`(_|4TlupcZ6!kV%rH*|e}oE%YQI=_F@?$&YAD z)(xF$1^A5dRbVdMFtd_U1-!jn$7D?|2ne*@xBBnU0w+#H3?b=KVbF9+fB4w~`Szs8 zLpn&$hNd=&AZ*-_H{eqZA{qs{!~n+grU7=CUx5+9Z0yYgknuUCipBZq2aR5+&Ig^MI`cl>tG z=G~Iyc4W6~^0~<~_o7U%rI;F7o(5CBtkZMs{X1@^{n7{sw@;h$| zy~_qF<+b-5FCjv4 zI{Q!3ze9694&DYT6x9BlM2#mpPU9XDoQO88o zt=k+F`9@C9uU=(BYgLEmwTAr0afLFjJ5-y$c{2vY4M*AVJIc1&Grk$xZYH09azT&n zyNbk?FD;CA?;P`3ueop@OTRQFI>b4^n|s{h{83JZE1so+ANc(XUM1Jc$y9{%4vbzJ z72!4WI=GYlyB?dfk8ccbnAo0%_~i8S6?|rmFFD-9_HKT0UPB})PyIT3Ca2glGp4;! zG(0hmp^O~|w+*Z`>+QCV@mb52bBm=aDmn8r%ihwM=TN*^?VC`$BC2cYKANeUrCHWL z*6{qgN!aDX!1n9%kDcZGw3khUI`o-<19_PO1KyDwS}zA_mFTb>yqDBo@xIpj;Ec=iAcsSXyVY^1? zQ4R?wX|Lcz@`CoI8_U=u>V(=aZ}ex4QFz;l4eqd{656dLl8pu}O zTf9V;KPEHipZS}RC`8CqnA8m=|LrUJB?gst6{IsA=1yZ6vp~db#n~^x|DXhKNdJ0jA${+RnZ=CDOgI;6DipZDUpK4wLoh%}aBOuzw_M32 zD{N#KtcAjBx6cVRj97@rJiJ$nDCi3zVJDHV$TlSjW+reFwGu0 zh*AwF8SpEzmy=n9zFSK(JbpGO`Q6WTf|V3WVUqL!8OLW?=R6OftA&V1NuBxV|5P7zbqLQ#j#LyJ}a`12qJS zNFZ+n{YCVUjSXrWK$FZehsEEcv|0cBiiI@@#r_*5IQ*ffZ6>o*pvdroK_16a6uLfK zUmQzhSj_iGgN4OLPjyr|>FIw^))TBOEzaXVHwo(JD0rVqL5C8m`s(Y5{4b0DM++bZ z4;Gox4Xo`Qj2#U~M%baHtD{4bFY$nz*{H&Z?jDy7I%_xb61zh%^BoyD+hD?*N$4*;t?jHhhUYE!TlZ`p*rHYu*b-^N@N%R z6u~#&!~F_kP?qMeP&t}{h=a1z8%}c~>?Vo~GR_@*Ki`v$VbSY^wMa6iff<0_-Z*6PF6IRX)@**=2JVFpVdqd*TPa`z1iiaxjx;tNhDXQdaTI0 zc$n+nb@m~e3SBw*#6-nKIl2Sw(TTOi#xY1_z}fwp0@NOlGXw1rB66jPo*o-~Q1-Te z7s*KRLZCUwsJh-$%zj`lA3`i-O+}}M!Nm@ukGd%g1A}mi0c=r@Gzb5LxJS|>xGg#$ z?R;+o2T974`}#u?Bq(3J6|OAd7Rdbio$|iF$!)@W#iBp3DVaD7r3t*LZa^m1@pUF7 zCsSjd6cj>4Knj;XQtgR^g6NEqE)J$5=_4KbR-)1Yi+~|94Ok{voMt$h5N>qk-T(MsAs&S0B6r08<<>+1y`z z9*T_*F8B%1W-8(D$-an83S#XgTOwlP24Ou4ycKfm#!Uo>UOXxPat+(DKEia+RSl5n zNth}qnYC|Wz2F#d0F{Jb6!IWI_;zf+b@qk>XQCzi(ZE8Ep;vN46pd@_NmiJzaBI;| zkt8*o+$~?eyvGOuAF##91nFD;aRlq1T1N*Oz#(*Wkgni5k>(MTLWR)~6JsV0S)8t{ zNwXLwfPB*}6$Jkykv)KWh@yUVWr>6_;=Cn%2h62FQuOf&(0>vuX@o7-8#t0^;*Nv- zm>@&gC!-w;KFsFuG6#kl-t;V!pMZz4iw$0mlH!9C@gyqdqZq$E!i0O2*;G=Z?tk912_ddBC zWX!DH!m`iNtUj{YIoF~#c{SfVO@mVVcR5$Q^J(T<;e?16)HjrV2fk+2yvn4>d#ebY1&lX@~Gy|l~C}hFdu0NzIKx~q(mTrJLueeX2|oz&DPEL zC2pvzwOo~u_fw!el+1d3yJC{4eE(9xJ^Dzts0f2^lHFo_O|roc8G~Ij6zUlFg$v3A zGCW~MRWDJ>!)tYSq=aAhC_1&INiT<+SWMRch5-BZ9ivqv;OQW{Uzi0N7;Az zt4_60Z*Aw^U_^U5T%qn86ZJ~?&hVZ;In1BO?^0Nvfi*bvJaK420u~!~1n5|l*uQ)eNhFg@UJ@|8tQMy(u zMj+7HD@C;JY-Db7t!cCW9$2p8nj(+!Z~nG5Wi`iR*+iu9cGvy4m3Lg9u)EIROia@j zcyPJNF-3JWK;6zH;$4NxNa_BE(sg?iIaT{8);;8ok9fY|A|jA1Wxx7fuHAbY4Pd(y z=e6v>AAawQ@7YC#9ZRHj9;_+;&DI@sotcr_J?rwZu%;KE9}Q`6_5Nl}zh9jx9{y#^ zOWVk^HcM6}lAK@8$3_O3arb1w^A(=sIKXM8JmY;xtvcHLXw6%loxg)xtgeq!sFXI` zxe~#n9s6#Q4sA2T$)q7e>((E?1MN&3U&kLBTR6Puxd!9dl7(7q{bj0kS63wz{sj3l zY5Y91wRul0dg4fh4F*Rzzw#%w|9o|wK{h6=EitiqS^s`=My%wL!N(_Cm4_n&*mj%M zhvx~@D~M>zrD)^@>-B8WOfAVv`Plh~ar4b9wtJ$^zTY;Z*8IX`8?PH*m$0K+{we%y znW z1E&@g*tltW9v|yIRhMP2Hz;tOQ%qWV`Vx2Hzpi(b7apij%w$<^;diV%&pWg?7xuIW z2X>v-i?7doV@a6U{Nd*?GZ{P{Y}!hVjizsEU&u?Ovz_>YG^KDH%tg%TI< z_b6{3PuqTQlvTX=*8TQtdMtsNNSxlhPh;Blkl4eExo?D;(zksz+Ie!NFT-m3?T0hc zOyWHM48y(}DK2AAA(j#%zJwqV>GJ~^dXQLLd_htO(g0jPL@$l<)P|2uacO#n( zGz)yu9z-8H{hsyU)~j0L7lCDx9s${2V0y82LFR{Q;}cS~q@_x`AFxNijXS-yAIE;> zi-FnOL^z_Nq9SSg>Nbu5>{=v47W+EPV~X8Q=3qMzTug$_pubay3>f3!M4-z4vcKW+ zM~Kr6C?6F`w8t5C_bxB&J4kKxLxCZaYJ>g#?aG8L@m*IQjEjx^+~_eU4s;T#jUgNZ zKz+Z20@SUeWYQ-#vjbFZ^5yxHqhy+r|flb zOXwhfAi)AJyjxJ3JU&Jm_iP!dFE3ZS<$nNcuy}9r#}Nh#V1^kxkB70)e9_ zB3V~t5<#Mz;y*o1dP7}q@n3Bbxs#;OoT9)Kj6dsyaAV$2~a zk)Q@56$tuDG;M@jf)u_qMMQHN1j}=hLXWlNQ4xuuE4$>*Vv&6Sow3#T*E$P81@w9tx-d%Qerh zu2Tmx(p}fqUUyryE-_Fsh%+2utGFw4U+F-=u4s-!d7*sMTir{;-aqd+T2yrKd-c<+ z4^o9pTz_parJlaPyiF-ggqMfr38$~sl^=hN+NVN}T@#NDwrp6sb(3Y_S6Rou%vt^K zTD@_Y3|T&3Hbr{LK}jZrhaX0CRqo>y9Kl4WuyM=Ye&|<*;rJq41u58YrJ=h*MhKDF zBiOoSs0Fem^3m5;`qsYtu&VcyXu3k!s!A>EkhqFs9D~bLa|WkVt;C=&00=#*cqkaf zBK%cH3o-a7MbVE%l9{@)6()f^_Z! zs^iUaH{F_z<{xK!7AKB#@vez-Fl=Wt;_7ccgY-eMC^GL21y&SV5v~f!1nvFm0*}nX45j-igu`#SNaSM!i>{AmAlR zv6(f#Eb$M3^Jg?B2N=8wgiB0SW#vS^K}N-)*(Gv%2(l>1f=8tbb^%Wdq6Sb>620tP zBGZ`^O@fba&kHzl_ ze4CM?!FlZVD_!NQkd%-jx$3}86nUf$MRiG{#oX|fk|Co)H4X)wfA%=W^xY*F77GCu zOORengrhY1*YcQzk`Zf%m5q%cq5<(6z>0>XEHB^qiXDoKk%jW`I}*n(BQzJa6jH<| zq!!TGm0pohOM)Meq;NqpMVeehYh$ntRirtR)`t|gxg4WfT=76hSAz+q$ns&|9b@<3 z*bfpRg~x%gcG@BzUSh*XR003}xJZhCFO4 zbCi2g+xdUENGM-IEknix+6jUi3w(Lf_T%qJ3@;griT5O63OCaX`~-P`fsDU$f*;vF z5uIv9B+}5KIXT{+RgmENp82w25ZJ*^u%aMNNEvurJz=#;=j=b%H{2Qcl4vDQVZSF8 zFSv=$0yRa&dn11*igUIga2dXUN0?#y2(;IqEL-;nzIqu~m=_PT1OXASW5IPg|IF1Aa@siVI|Ijo%E5o>+gJk(bag_#59PqdB zM~^-rU>b)XYE;pqYs)jwsvQ3JNF)6#c|naNC0s2FA7-FC7<}&EIC2T=h4gdi>~mpv zAn|F9QV!@>^31Ae;3y!sbNjJqG<0M%J?0I{;`{_HL}vPsYx_s+vUde9c5LVeVTmZj z@mo4GhmN0QcpUj=M@6KiWzUykBbollAL8_sR%<|TB}O5)Gq5ZG zS_VJ0DfUOw34!0v{(xB(kI{b{5n^Moz5so|Et(JOD2yEo)kn%JfOMWF1U(HSA3Vp} zBHuL{Yi}6?vB9(bb3HSW^L#J8Sb$n$CoP|eoign%Q2x-)icZ^@F%cUsJ_8J>imejleLYu4JjF2abN`ZB?=^^s zE0Gg*dOua>>(`GkSK|qawK%eIk<=SutyExqXUMR#)2Hj#F)=GMb6XcykE^fFbVf&9 zi;0*YQtx#aO^(03C-%#k%B&O$n=iTa(e5{w-zT=3=5W|(80(v!pRlr=bn+6)&-SG+ z5ViCW3t-*qBe--xL^><&@Ds+SKiUy#sE3KbKk4X!BKL51ZL^<~Yk67Abvjh8Ix3Y&G%BF2HJ|?Oy(@K;w zUHq;}RM+0fFJw^{RjfOeD(6t9_>1bL%Vyj4+y+g{rVMHFq3#~cSqidyT4m`nrDbad zsjX*N>1v##FBYjEFx&7exM68+pMlE<)Izx9$ zD2vy}K>>BW(D^prJ9%8y%%4*jNaTcmhG*~d}_<9<2&$IACK?$(nz?d{d|F~WAk1X z8~2zwT^5DMyUH>topa^twGRHu<-L}^$^E{sXQqrg-}+`3O`omK7Fp}Rt!8+vU*ljb zYckp~apMk)&pZFvhGt32oT72J9EQD&ndkjFFR=C<;1S)%w|44cdv}@Dszj~~cl7oL zrjv}}2W$O|6B8f5;H6gg<1H3Y>`p8kzdk+GkaUK9Z9_IcMU8db`PBGFR{I!rU)sge z?lct1y2yA)x8kbP{Ver_zt8hBEkn~^tGW%g?`d|UFV~XzWMmV|9Co*pb_=Df-QC~| z!#|UarkGcnlAUcBHvgRbwwX!bNLTVz>kdx+@6y^ag}kfTs^bq1vTVz;QS83A{?4h` zexbM$7peZ7_7FDN?;~5p%!{(JK0kOqc}pYc!$oia``7D2qE5nB#jh~s=>=j*)|2@M zK*RQ17a$xVI5^mbDHP=r*>MRhNKB$w%~Kk7dXa5=PHDExd=&bH^?H-Gk0eL|s@Z;f zG{N}ZYm_`tHIVQPhy{>X*=H-;hw1<+A0OdnD|=9lMnH9P>Zpz)|z}elXZV`}lEEHqFtf$-NJ<>b#vAPv0xOUiXp5ZtRc^;uxUH z0PJN3R~&&wX%%!xxGjkpAaEOk+3qXwk;dVq?e4pMQ^bq#&7od&d7#X&yyt_U4!7e(S<}_kRS4ty z(e1Zu?15P#^lgFGbD?U-zf5~gygP}bLVQuF&@v*Q^08~%Xe{Jg76k_XJQalOM@=c= zfWejWaYZd9`^d;hO{D5Wj8=&wwDh@_KdN+cCSULV;k(VIKSE30j6i8O2mOst{ll5^ z9zu|J;G`x#2LVnc&M`u{Sik8ZevKz$<^e7Vw^2e8N)LtLT#Q(FsukKuw~1>em(Vrz z-jocMH!vW%pM!WCGlY^plb8L|>!SAPvHk^E9eVbC02(duP?g1>+GORYZf!$@1rggD zmv$!AcJE-Q>L+a+`oUq8aHLS{)e{P+M7eV`I``($1`-d3F7}{W>$W)54x2GYQSBjU zrn8TSW8i&>TlsYbrKe-X`mJPQPNQ94b8`jqr^r0Xn{c~e-uDc{9ugAgYxetfj(ZS} zfAwRt+Ko@joW?YX#tdn2QO3TmsyhipN!*1>2AG zoR+kV{g%4bDUe7kbi?Xs>zp+$yRlP-~f z#+jL)#E%juDm4uB__25aBgYzBz5Kj7#;WmE2r^~;(%kI>bS#0apOQ3o&iD%-Zf$3Tc>zOerE zjO;%`joLc!@%q7q%I&qyrtl70c*HxS$smJ>&<;2P!)NtZ$lY7N6q#!E&Ax^Q1fT;P z0c72x6Se2|gOrH`-C;|(W-CfLzxN<=%dqc-3`vsL!04D{6Xj=4nZ5X3;{prP48=Ny zAGtwfWe~ZIpwF47XRbd}h}d(eh{)U$_+Tv1bK%wk`oD#MzePNQP2*>90TKI2_8t5u zWF;<;^hhkFe@02)i!t2hI`%e#_7jooO^i>Wb<(Uml#rC<4Fp|TS^4eTNBZn|0lX1W zjNy|^-;6;&exHzBk8<2nx+T|p%l3LKS(7aSG>_E8t5!0L7l_L&L(vY;<%9c6e^Z`4 zHXAYT&bc8LJ9W^Y=H!l+C;J|JTw3u!8;yD{#&dO%II7A2_ZB}sl|p>NpkW8_W#?mr zQ6>_+0a61;)Q;lN06_LQV*U(4)1*xz^U9&SEuA5}9?^i}6%E~{KY%FI(AFN-kSWKe zP3EbPYm381TgRO2R>1y$fS?`PW3osOCGA{0Uv~k~8xlR)iGR%@N2^M4_rN}o&Ex~S zE}Doc&MPqkulFP69eHmgV2UJ;!^TnOTwGY#DKqnti>5zY*MYIcxkqMgWMsr8iRsbE zlk8B%Uh{Ks_+3FXDKYMnHpd);t74p5RJ}@kNm>aB(AhpA@+;}Z($w0jEnKsB_Gr*Q z{F$3a47NCkOPqW=`NHnI%xwcd_I9afs&wg$?lDY$LwrGK-|E~Oj5TW>*FxJtrYRAQ zU2-T;98m%((58_XAIxR2tLD6N08_aJrwiXG;h-UUHPD(3WN+X4c0n=cY^oIm&YHwH zL`Fp1Rwu*hkUaB|xTUg7|C~aHgGVj3tSpxBVrUcl{qNz<+i)XdP|^KSL*ZruDDW*O8f7 z>GAs4vOkgI0}kpfVO>ee4@`HUI!85n>VB;d_c$=!Bq4gN|7Bp9N*Zh-zINDS-w)Zt z@m*`{F|&;j`8a{me!I78M~rwE6cil)E@%$tUCEfQp>45Tt3+dzV55zZ_@)v_{8&P2 z$^y5mez)N8)wL@hzg3>&8k9ZWdu#?p!X-c?-#nHkNJZDq*CwZ7*?62l)N&IzWdNAm&;nS+d8~tqG}KI{*#O% z9NvdIgt|?Z*dF(O%yhB4xD+6`ku}he{n8Vq_{t#ZH(Wg7uOnqPLBl!v$B24?X53PW zUcEDmy`}#C^&zgi8*ir``P-wEW@OsleZL|4Kzcd~%jr#RO^y;BqkY#eP2Nzkd1dF+ z?D8Y9^@WMN#22q?2M#~|% z+1QA*_zO4rp2uCEFPe6hHx#jQ=kRNq9-_~rxO_uzGG@u5MJ2h9OJ{QbB#aT{Z;o3VeVO3`X@GPZ1M@qbKHRMbf!RURNdX(-%#1?-7*BMxVpd^ zxhALJpvd_}Q=gAjL$m=XdZym7`70c~9r&zKW&05R9gX=QCx^N0TkL7BQJcfHR>0^F zf>s#>W|$1<27$i+8}*^VmeN?2CZUZCrfBRzFco~jv2~k=t7_l?1Z@8qe6fHDfJr7& zAqLlD)3U$9-Gi5^?NF9D4*Ku@gDPSg1(VA=O1JQk()9G7UNn1P+sDoMHw^-LL_3lXH;3{1suH<|B7*GM4^?YGQI6b@o?>U9S=A*? zv-d8GsRABEM%>OV4~VHEvDq>wngzOxUP(r3DRly5y?)(R@~bxSH8Fwj-yZ@v57)Ts zN?h_ZnV*IV`Rjwjfi*+MW*u;)fUqd$R8Dz$>HYo9#N&z2PC`$#PFpRRIETJAB;~SEr2Ig(dvzb<{CLuMh9%Y!qWN z>kDoAsskY2|5}z^rtU`#gU#L5Z#ZcOPPC1@N+P<{EsE|=!HTBK3QYqXSy=g2N;0;;VT&Ey|iHTc>C6%y^{oB z2nofbT*s$&xkmL)I4$O`83(ZpHM$R-(^6%lh4r~@)$6cts0HH<&x3j?=IwYSNJR?aZigWw30js5_oIRul?beUvhw6!Noj{OptL~1c7+Rsf4&P48 z|5}(P5lsSH9^7*K>uYuZx;efuItWgtopwc?A!smu0Dfgfx(Zyf&eB`rh*u}X1jK`Q zsM@Ws4^2x#$tMjY{lftUEGq&ff(XKA0`QBEp?btU-}))qI9$|9V3wpEK)DG)C>bWw z)Z5!@9TM978%txFTNgm)Qjc7)MpQ zNVWm;73zh$onz3MB8P_$Sp}$)z-WCfwCBfDOyo?bQs9n|KXSWGZ?6&rooK7`b;b8= z4MC?;af-MYNEuwrxBa=43|%Jx~MUc_o_}jJxD@;+%hVb$641 zB(&hhZ;==9W@Hlfn2tr2cHW00-1onCweSMt*8Thj~I4>{ZH+2@iGnGG$b7wmvZ^8uk%bUK}mLUQs2p-zCNKk+l}vJ#3ZE2325Wy_oMJ8wwjx1LRHi+*eK-wL8b6ctVei zdmHbe`Uz(c-+ML5iqArWo;>jjl9F_^73EMTpB8L>^6MJ`FYbHk! zxEg3Zm5WL}R&F<(SaZYL$HX4viNI|4ZQvnMsV8E?PTK!#Lydi@=_Bt6EtA~64+y0e z$RP_V9~k)bUy1*sLe-R%sy!NKKz#k@XHU|dC*&g`c8CP?+SwztDrpol;}1NRt6sa) zJZ`w5WFqdLeP%Xblzoe7%L(m&!HLH?=sMPYUR|c8&!)&3dG=$SV60+A&wZz6{n3K` zNly-Z^`!A>xWj(%jnNP{jh_jFu~S`1j@*Z12~GuPhu$qE_s?aVxU?E}HLmMxLH3Y< z;ZX6QSE8`7Qb6u{IT_0QhkYso=4pR729;aYU4Q4mL(}y2aLP+-s@l_aA141ymfOV- zQJ0OFqy;~xH>UB=RlK71jLI@@dGgN3Vy)nS7U7Ks{r5e}hn(t`PyT&zLh6R*M&a1E z-Em2>dhwmH)>?M6u?_Q^`oH|{da)y(;oCpwrN6%{WqLDx4*K-4*67-2R>W*dwP;5AiYps*N$h$vA9QnCT+LRe!CLW&^MLS z%KOaScs;|Mw^NdU9 zDI6cjcP|$%W$`~X{2S$W)Mxj@TH|URe>=%v=B|HTB**`4ztt2TW=M66e}9Qsz!?oT zxr9(%-uw)zVPHi+&^6uD{q!Y&uanW#`A@I*t9YlF>r+$SInH<{)ye@4wIlK_E&;_@ ztUXI3!CZR=#b! zMSfS*-X!6>QX!$%XvN1*Ln|lph;Ms{E7SH?iKRe`pZ?iH82!2A?4hEPc88e~22*qW zXVX@m))8Wy&(F4~?hYEJsyGmM{Oz9oleBzP!WjyDcg#49zeJnsV{cpb@QnL2n(lf$ z4$N2X-i(f;ql-{b))ZJz9mx;9$jdzXn(6(;gE>vVPM%}?Q1U7xy$&h%XO)#{0V$jy z5%o+jw)jXiIJLnO*QpknZ)$B48s@da8EKwyD{7IqBd0tth{ec!UtE}y_Z@~#hd>ly z#o6;#uf{eRbqjZQI=he33yWKQd{phf+J36L0Qz(H2JuDPd|vlPsl6S)v)?|R8e7b| zEH=iu$+ba^)BIZZVW5Y8!^@*xzBi9fDxQx{RTX*kCnj8T>yl%&AeQbioQr#WxcwFI zi5&x7+Vq7DR5uHD1g_5MEQ%Iw^^r?h$qxI&ssH7?CZ(!-U)+qq_7GrA${AwJmo5AQ zxgPaxx7=~h{aVJS_)rIK6()!BQ>Rp3y!6uDk(tJ?*&y>`qZt1w>3iR~$NvQ`zsZie z{ikgDW~oVRMB|2?N2!XPnC9AcFD_||C}N>|4YkY^Dew)nc{*^{7$nhN*poa1u+Pb8 zp4*%%HfVTNc)puT|G{t!N=y#lX4{kUo{hDR7M@F*9ep)J_O@q#rh6MT@*R8IGn>g* zquS8-?Cl%tr+RzaGra6fLT7T)UP|Ds*47oPvg4IX)fC>ZX=t&#tKe9_vFI$%x9jv$ z^E558i-9pY4%>wTTBdHL?RS4;UD7wV>tU@;)NLx+pI+Zi-_{hHg>rLBy!&Ku&X;Rr zA2~-eg~T@q#6~$8Zc5%J)++kw)GgVq2Yz4J*DYwqL32E=(eA>wP+=gN0s(5$IbUS+s5A# z_}-$m#Ubd?cslnul@zPfG4HyAS8Z*Tb=o|piXGg(7yk958e=USnW5q|sn5p|)7x3w zSN~e)nIFyF?D4XHyQnq{8@KjQP<|Y&|0|)FUieoeZdT~azUmZ-;7d-it~En?JdHD> zdi9epc^?erS^EXZ^O)F&zCF<{xxMR(_p>Ci=MShZPr365-J_B?%!z$PG z<;!rh5s%a!R3t7|lHKLJcy64Fw!ij+`@WGL z?H<0|Xv41k19o;RVH7m(1Cn8RIJ@SHh^-swl}Fx*6jg5W30IP{s8(qh zl$aaPGZHx-^7t`cwBopngn63K($BAM=c`B0b=@kR7tQjEJIj6g){Z@8lD#8MBN{*A zl-;)1jNSD3p*N&Ax2Mc=yFv2zm9JhACQ}Td5-lqHK@u$sNA4whPpq)PN?77(#RpT}s_sg~3hS@%?&Kw~W$;T-yyQb@Dd%G$4P9}Q^ z@D@kVYOqD1yZQ0IZ)-1L=6#<+JtW+aV*Reqi4oNrcZVeDnoAfA|0RZ2Pt!@6Y;yMK zDp1?Ig-rOoEVjX`aYUx&^BMjis>*?ek*a$V+RHLo^H!HS-$_P@%j$Y}+&S*}TxDB^ zkz0EMs^r%8b|fFBpkY7B``hZs2D|SZK}yhwz0co!ruB?g$984Om}95t00Uz91$GS3 z958h6c>K@4RGzNcx(v)=Y5P10J@pE=4Ms6M$Vwzyd>?%y>NE_eBCSMC!Itb)|O;)9>nB zVO`YylUpmoQ-h-%5%8_Aj&ItBsGn%T^mLCOfp_yg^G6GrDXR`3Yih0i@$ialM~}B^ zhj&(b`jb?D=E3_WY{QbOLoc~^zwPrhRHjdv9+Pii*U>q4?8K8rtM0=PV-V*Eh`ZQW z85f7fWAo<>43?q{rdk$$RoQdf9yf^m&gmSoYxvelkF^=58vCP0Z>*QJ!WO3)imL0@ zHZ3t2t$fzlNnM5WAO{J|G%%z;B<3kB%-(RUW|{$IK+TY4bh~xz;~PHN4iOFMBuf{xB(tK%&EFNY z`h|8oVZMvIxc{G$%gfSrZC;r#j?3--XE+o})pF`f3m1RhSA`98H&wF(q7!a3a*R>d zP-ZkAUk>1>JRn=I-Awyp^p=Q_sq*^sSwUZa{N)iHcM{Ra3ouSkv=Hh1%yi=+Z<}Sp z!qPeFnnMoTQn+r11w^J?tl{}(ai^%>p+x$eCwG$=`>X5o$}fufo7SE&pQL|f$S$g^ z*h!huCi~=~DgUbt7f#AH>dI1Ov{k4wEPs;aZdo3LMesYeCKsE+y9C&mp z!HNsi9y(ae5jhK50C)g8OO_gYcxaY9n>Bo_yFm?U>m1w1)CcGRqAiTqQN4Am$R#gS z={u1Efu-BgyfNKe9}P1Q7#Qy%RSP{M2rZo{nqv^H%=GtyJLk`|LO=?QGhR5Z_)`e= z-gUW*8 zyIR<_Cm4#$OvnB5&>6Z406Qdc7TPmn=Pz;`36G6+KU40B4OSM_xyv_uTCB{Tz@SL; zJi7t8|52DJy}1HpW0?}Kt%Ziy0gjyfkQvs$WkXm0XBL3royyV0E4BxzXpdRRz`+6+5xqPV4b1-;Dp-^8^dUE zu(>`WIUbU01LgVcifd@Mah-QPHmuaY61)DStE)^SwD4D|C@ASqt^0CD#il^Dz02lR z31#iE$Zc6Y-{;e_LuL3C?CwsQv)w%R?Zo-cgn{-OLIX`fl9ai;dY3|-3*z4CQ5`x> zp~xE^YIvo2?&95r8o^V#QZn&$_kN90hzTwotm|2o;-SoFNSG~>%rh5`NZe<lx{nJ;J0SSN-C6egji~PRqIcZ*N@^L-nj%=Hz0Ax0gIkI@=ka%JwIl zVY77Y$vqn9oYwE&t74sJ=`O79mgOJqPJMaQ%aV0eU%E(_mM-`Z^I?TP^_G7Z8B#`$ zP~YG3vH0EDc6|yrd7U>ClrhhFi@v7L-jjSWv4uOE=bV?NZziXQLjz+XqvknpT-+`T z)?0g2Yr0i`=QZ-~KdCBwS#AEIRGZ&H1;_3y-aK30_a52M7<5T6B>o}IJL)7A{)3Qx2Ra579;X^canld>=Y09ND(BT-d;H6~q^z7Cc6k++L#a7@-g ziHq8RZYQ$429X3s;@Zu?DI>JdUBo(~x|>YY(cgl3IMouL&(1m^=m~BOBFO+WdaIi0 zYbr#|(;!#557MfDY<%&;9~O_b0i3=>?GI5acTEERl1#>ius8&}B(zC(o%vx*&j7xG z<<)~DLZjAFDiYWhI8-fteN_oAG99?aD0!p;5*6s+n3qvcwSq9K5vq5N6COt~ zxo-gLS^D{*QyOg8R=0;xO@7GC%9;^8Z<#OBY;9CVsLpr=0=ol4BM$eRhj;pHSpCkycr(s>TU zm_4P$9tqSu?-DI$-`w0kWQH7fGKD4l^uFm;b&Ep~F95|U=eG!)8XfHiLq1H#xB%Jn zTPbhk^I?WKhwMbCb4p)6O;5+M0vQQ$<$8SMC|7HJFQHaG9b`*J-{I5qz({=za)71f zvLQ**Bz9S)4M;LUwXrKqh5OMVIg7&|E*@M5Kpp#|^0G_~y_h^v$}(_C5Q07zKX|r8 z-;beH0*ZYJViG{ed+Fbn(>_%LR1N>~_dWpLF{WZ<)YE)@a)jPtWwm0HN56T~$3UI2 zZ^p<}Ex9rAhcb*!ap$8`VJZTp^FUXJT7Oo0G@m&5Fr8UHI&n7V;K#~CEQ_v3cYb49 zqOn>?KjhP+x+if^$NZKs)FZf*s7IaW6|L-xuOryPS8Cv1{4&&S7mR$#%U zExr)KM)oF)Qsrbrhxf3%uvc_O{}Wz5$}*J?M--}r+AMD$RT}Q0K9x#IDIvdeKAgGe z@q&gJd(4xi!04>R@7}hS%!f>p7DF~G+}P1PtV$PS6MEClX)VHbPL#yd5*u6!W+>GZpKwg009$m{pAjS9Zs z-+oi-+YkNzJ>o(!kNn>%DQ>1eUve^!TB$rBM^AD8$K>xt9L=82Ty@sJ*?+Pbvhod` z*%6Rx>OvQD_)cV6yiUe~h|{aHrTaR%v&T**-yKBeInIgS&?c}_A>1&Q{J0dy8qc9=doNm zbB^)erw7ZIJ$IfQz08xaPFTgdeY17^U&)*YbeB!~(|#8-yfR#GCM=mn>1G;n*EZMJ z_<8g48xH@onttla9Z~iX?zg;__sDU4pjU4>^d}>Vp7Wu$o450q>yLgX7LKfAcU@o9 zM;GR)xbBd0g{p|;1kVqN@7oPcFJS5fcbx@(|gyFc0qL%ha zmk|}GN4_fdDl>A)?~WKIjmkt(SKnx96Oa{He=1lmUhl&R&6*<$Nn_JOH%ey4w*3uJ z?u_4>PJOm{Ltw||0^Uofyx!m5sLVa^>y5Y1px;7`@L;{xJ|hEhwl$%Rmzu(+X{v|p ziC_dUeC_3Ms5-89Pbzr|_z(h&sAb}!Hl9Smhs|pF&*Qh3p>D^fwO{Rmjt(phFi-q| z6N4DZE+f@G?v%DJfOleBDMZ{98j5eQ5yJs9tMxe-D!A)?<*{U!!zjixt(U>YxM4Du z;GRA8;-=gizyM$?Ble1-g+=RS@D-d{)!Szt6 zBD$jkZBnt+%JFMNv`PHe27Ujw5CuH+YqMgzk7qo`Z@^AQh->1ng5fCqjG*hRs_`%P z<2%yRx3d^U_f4gN)PN7*AvSIj1C8>qwysWEvqzOb$opSb+wIv#^(3f3UcTz!#M@s}S8nwj<$`Dtj0`+MCt!=8YI{;R6fA>-#J%pc z1{(+*qWfAT16(`qyCK{Jv?OGV&-lNqoP1`m;<0u`P<~NGKDZSUCAi)1C@VV>Et{t1 zV6!*RCT2P2ASIIXCiy)=%M#@?5y>J}4l2mG#WUE5Nm48cNP!4(c~toXVN8`jsP1hn z`cJ|%S4KuPaY>Bc_Z-&)*x_XK-?jPQ5q3lswbHVE$5iM_*ejQJ6phwa&prRt?uJ-& zBgC7ikuh9sH7epmo)So8BHAiMtt>1eauE}h0T`WW5!sbLa7@Y5N6B?SNVKDD{Z@dr}xy5gaj-Sdku)L6#A@iHw zj{W{>u|9wMR?mK69VtZ7sT0}Av+3#koV9g!{SRgz?V?{xYo3y0{&;_Xo91rq;(pP} zR>Lykz*{;$+^7%4J-d?Y=PYo)sY%MZt**C&vF(O0>*0yKi4OhKH-)4Gx|FYd3y9&- zQ+hymvA;&aMJONQh4ME)uY`q1j;UmRGt}y+A^|Fgh zvs}e}2cl8}tXQ`nZF%S3vulTr@Al#YyE3Xr8QC+kubiz-?v=aTUpN8lR=<6%m!2J4 z{ByP?`hl0GV*S{Qq#90}i{Ap{wHJpdqbD}TN>ACP0^hD;b?D9^p zh}fC)3$l*=zZ;(ZThpFH_gGO`nQD3$(na8Xb4%vE(Q0gO+&%Ofx; zp%*y=md;u7GJ_xhYnKW2(Lc;cM*o7S&ClT3x*&AvPPaq-3ngLPmAE?5CYLG~k&B57ll@PPi0#dHsz4`L*^#V&|cL#Z-j) zImyROsNgFiu7zSC#rnX3-gcf0)4R_!X{jioQVg-OEGjlt5%^Npy5)Cn5OLjk2+FdU zH|uo+`bThlAs#=VtP(jS;57lv-3j65NKl_mMo|ADVwB90oMj(Dj1A5#2P#Hhv^M?3BPo{2R9OXA1GH z6U%G=9K)t-kGHKS8s>rBkuAjtrbbx*-&BE@`?>e+`5C)(0<6~$Cn8Xc3ov8)2yN=S zckl9a--{(0;JeU^K=dafCg$1@;&g}`-9KQNjW6Jrlie9KROOe}=2q>yL;4YjOfWiG z8iCah?B}n$qwd_3=X&(-dYcb_k4^p)k}3tR>Ur$%{Q2eoPqnZd>4t%KYP zJMxDHfH{I}c{N9}z!0H;K42}WB*ATQb*-p)xWpXYJm^LX>RSa`ML@C`9*!?A2I3+R zJhB=yp@J_3n&z{rg~B<9rE@h$gEAw+xDu}YHl9A#8L^LWCBAaqxz_ngSyoIe24^7h zD(Xb*gGkOvZNF!Zs?d3Dw5I}o9U(S3%7hL4nW|64y1v3h$m z8;gJY_G;FRbgS3WAI=c*2$4m87!8B>dSOAqmm=Pe8~C2ODPAh@3Y#v=n858&N$PoY z{ME^w9bu>U)X=FV>8S`TzHtVL0iuo08*)nl4g+;3hN{d#tHQ0Kz9SKGC!L9jDYRSn zzmteWEUZoJu{r2qU}Dln*>;3}JgoI(2K5IYX6|1bgHW7MltN4gJofEeEr=|5AHha> zWLcXl1;W)pzzv=Y7!7|I2I5?p_3N+iA=~QOQ>Tu6;`y)6_CMle@H$g{vr1{5!5wKZ zPbio@eI*=@s0c`2GZp;#`0nfD*XHD8WyK+ig~k-C;Z{Kqljzs`fXWi8s&;`jzhlQ- zQLp7y54wAs5$;YNYiPQFr)J@P>13)y4HZ-oqgJzLVCcbpIjocIB(eyZ-LD-xO&ov>0dm9pJXjaH8~mJ8 z)uUf)Elf-cQ zfPlxp{# z($|Re9}^QW#IC^>)Y?nnOoyjad}h8MrIy5uj~WyjRC(lMR8#usuY{*kHP zp>v*o+}z8CUzh*0(Kqe-cRcNf7JFPawK?OLqk@0+0;3Cgmg#(?Y8Z1JvsVs8Z{_o> zaZ*-=v`%c%nei`U|8Ir)bLSuISYb%m^&_z>?os8_0hKbHj{R-*y~&QbY42Rv9Bw@S zdHOVCyrE9N4bJFaIdmeGeFYPKS1o$>6!Cl)D-6{ct_ zajIm9*X5-#oQd^5SFwRZrC7M;y=YD4>DUV6=!m$a>ysOpmQMS0=;al#N76N!3VBBN z(p~(rAE(W&2Y7CmEgneC(mv^P-se_*bG-aok$+j2N`A?^-vqt~i`Mo9s`_G?o zjK}D;A_|2W?b0?CJ6hD~=YQqgE`ErXwnN(R_4`dc?E0U$Z+3l`8uvE2tGMU$(}XfQ zv*To%hlM z^-UCPjkk{LSH#Eq?fU>9CfPUrB@0J0p)Ca z4^rFd96#_Qg5!eHFy+o!GoP7tS9>`e5J19+6-sU3*JwDRfiy;3%0Vatfwpt{obs++ zE{Zp&Myhw9E6NbVFZQbi*kH0%zyLwd~MjB|jNKXklvlHStXQ}drX%%2;3E&%IwtYy> zA<^K_DT4O)8yp=H%vQI??dHZ;1pyaR-GFRwgC=Xo+F_dt$iB&cHz@Sn21AWdP zWDmXZ*Eg^dNAQZJ>aBJ9JKuX?uJIO`pRkUO?D^^jQ#ImDfXpD+`jOsE>+cWhb;u)0 z2qL@y&?}aM`ZpTJn=XC{261rK1f6%~{E(;8TbLihK<1u@5sWS2SUUQ!j!)t(fJ9UH zL&D3NxQAmOry67FxzzRY9*pI|p_@2WtJQ4tZ!|J0v4tW`pJtdA;uRg1=uh@mKEci;mJ)J%}vc<(2`ILAHGp?hsaziK1o9;A=(5lmg z2nafK9kLb{Kz%I}KP%vy@)Lrw!&j>5F_U-HBOI0Zl`Mq-Tt$C5z|TKcr> z*B39+CMb2}l36HN_=vQEQ6Z*Io4_T`yvA)2+JMLMzb^m!^$Pw49AX+Fv@DB(Q6QRW>sPxeS;l!5Y8nKKFtDDJ)^hPQ2D&ctBhUz@wzm zA=cnz(u`E7ofPxh!r61%0B^$x5n@L0@&D*iH{aT&9T2ed{}}Tuk&sY6Mkb|;i%U4$ z85m{a%>|v*2%m|U0QZ5=frC9NbEV)d7q&e^Ay__Bkx4KI0;$>Fb}0&jDq z5)ch6aZ`dJBn*}Q0P|RvkKE@&ypRyK4%;TUMh(@6sl#|l!#XA?TS+|<-PmB&v-qO;-d=`Hicj*#wZt@)} zKaVaDHckIZe=^aU`K9y`<* zxE3@(uyPd8Ja+??W0Mo97jN9)!TTFHDgtZ5{|g}__g3->ao~KC`6f}Ao-FzH_K0h_ zRb^#tI;>>-#?rcBCYdCl`W%9>cgMG6mUM!My%#$#+@=NwYP&sq!?FXEP15iR^$g?~ z&P@J11MlK&W=@kdeOMMhG@_qAH}Gbb>cC&%MX_N;LAkr zcF6qyJ{R+h9PD7QWklld(6h4v>RHq0z*WUkBj%$R+CrFuj0G$xH%yE?mW75r&{xcc59ENo>f1R_GBd_!6TlQOYc=&1 z3yo3oI9yyls!mE&SMQ3{W7AH4vaN`>MDmS#cAk{5J+1zgxFZ{y!r03n2-{H}@Gco^ zlB&{_Y`#%7Y3dZmRd&>gQb%OJG{+#XPFe9#O`0snKNH;==8a`k*Uxcis&1sSak7hx zk6pURAb8c%SEe#OjUva~D*v{7!X>|^s-B}&hK&j@6Qpf~?YB>iuHvLY+g}xgCV63#XUCBrxL~|W&$XqXL4V0 zxXTyW%Ti79h^KRHEc2$YDYjbBzN`N-vkviuK1s#_su}CHnlF0(B9T@2KP%^rzj=&GdPM48#iG+x|)>rne|Qp{F(0 z(|2a3=cSnvz4+g@q`WNNCWlf&w*Ct0_Y*0AQO~Y?&!?@iiCMCZ3yy!Ol zq_$U@^6tfJ!hHJqDGA+v+Vo8N7O!aPlOqea?;S0oyYiw{LY9_J|L|e$Kj&l$xGamC z7W_B9+z2E5g0lvZWBv5HX+Iv{(H)uNlXo)X(z(!BCf}rN75loXY^R>*iP*7vsq4&i zW^HuMJZv?U+nvH!REr{|V(QX)1CrQ=woW(*i4Kcr-Trqmckzmdj()0${r#l7Hy+x$ zJF(LXCUwmpVx2iImLtX*v6{iUaGcQ<4vs;Oc&0u>Gb^dO8XHmSnzc= zvY1~Ynw?_u8n?@i>Y|ClBCUlJ6)pne9|fG6rr7&rm}q17*?Y+{35yra%B#>m7&4@x zrm5a56K0pgw%0@K!w$V^Rf&qVhkDIHt{X1|M)e1==_g0d>WbwW99WWy;HPhRVneH7 zP&vKzc*?aKl*gz>a`kV!eOoi*5zMvU#BN1@#4G3Dd0Wvaspy6EWcDMoqCXJvvD57p zs|E+W>AA8xTj6Ki@vbIdMD?M)lOGJN_Wb?v0>?Hh2S;AO<`bxwQSRkP{im8Vo#=AKXB>q;&%kK=t*R3Yieb#)@dN`I6Eib7*nj0+ zz^t1X+xNw=LZl9JrCW$)jKNk_hI5}7ijb%v`x1a0*a*+A&Q_Bk5|lVSEbG@_{VPlO_@g7KO7-4xF-ReE4w zX}aMC3DI*EL7jzq(Sc5Orw!I`hC<_ols@HtS1 z=fl7bMx5;=iUw{0v509wErB8(#)B~|ny;d1nU*mM3m>4J^(~+T6g4tRe4yphnJ%~J zBK;dvb;xyywVY_XB5nY&Fh^9;g-!f)MTZXON~Iwz5In+?YH4WQz58T{i#GtBTq)x3 z)#T)a@k)qgqmz6WY|hS}RV7p)uz!Si?zGPGIsM`6R^ly;pC@MvQCPnm_ZN;A1%`!E zq1>s73TM>mH}C$y)Yt&2Kd~U*f32FAa3s%i;3@GRJ>^MqO`d9)RmL$yh^S zIV)pMq8#ubI-tQqDmZE}px3woEU3xH$Q2QH5q!UpFL&YEhJR4=yG+NrYd5)R&G06K z(yx!>*w@g8GHHJ$t{b6ABe%?RX&^hTRR&oZTc#{}IxdHoh0wuJV)Eev>y5*+ANv z>F8|LQQGVdZxDL=J}Fn~NZ3Lju_pyJAqIweV+$b*3uL4U62@K}^+be8euRSfv*7r@ z!W{11bo0;e$(b34>+a;80+*xnP=wiU4dO-sZib|pv zyXb#(a@Fulp5*kSfSH zCO%0hGl?R*Q*$+eq-X?9AHa5Dn8$`q37#29J643mQ$+6j51XD0ulHLy`KQY)Nf5B&@`|VreQ-m&k z$JG0$yIC~mSyB0s#3-<)I3i*9x)VN}*g{(2C2N5c3?f1rwy|f5ZiG({+2FAgsQ5(N zVq<~|)qS|a3$<^R`u(kDHyZKIz>>U&wz>WLri-~*9(dq;N|!5#?fh%^C9XL0f&HM$uO zZ?GX^{p`W*7j#9?$P)W2YPyYxOJPHmjGV%mg#J+_4HXDk#+8;aYUiNL7Ft#w)GaIx+G$%@DKfB+v7zqr)AD&g{#( za5+LC#~nE7D8h9~((FdI*5R1|^Jd52t;*=*h+8^@^$2QwklX`M>@0|{#Pj1-XXn8E zCpbo1z}R|;mz#Jk4?=AVTnM3J9~VN-f}KECJ8^cgD7=F2MS@#M=*&ld70d;}2vdMP z3fT@a%f?wsPIH8p?EiGGet>?wR?6sv_qHgD{bzfWtu61x&~WA^`YCOXbYhd2V3KmC z|CT7)cc1g1eq&H_%rcuFe+Io|VUc(2E5@Kle~MP0vb1c}Rn+oGT;+OAy^-geXgtmS z#CH#COj*CB--@ewwUK`L^m(r}wFK53Q|W8J+m3JEC|oV0R4*hRC~Ugy&)M=oyDefl zo6WwSxo7?nYgCCd=hNu-YpuVt7$xbi8QpvG*lLqf=<-d?lI_AiGtLKMdNpZV?t5Q* z)uA}fVRf~X!8OVFb>FeS?Ygg|=;@o^JGV?JWzz`nmk=>wK?> zv4;B=Y2hly8^05ztWxS^GMU+Vj!<2&qcor!y(J%QR3gddPT8p!?6To-(*ru4L61Ls(+eRTpErTj)cX55;6IITDjWwGcgRvc%ZSW`51a^;b5Y#JON=suRS-T3HJ z)!fCi�^{?V}`1=apP$i*1<|x3Ib(@C8`=Zct2aHbnYlemkd+QGxIn2xdx0}W zmr>F5Y0X`V@}q{Aihmo)S5c;Vgr4DieO@wy`gNQI<=jDM`v!HE4lW&9wtPAU>xYy+ z_o4=8N|rTQ^7P$}Ikl}k|415N-ppm^Ir8@E6~>bCgJd*m9dopf}I+AE7oznY!vee(v5{x+uzuS~f9KIY`{KjPc8 zWJ52qcUDD*>hM;Meo^XNhwLC9ecBRtRyX5}Y2olvu8r}DcA`W;bhb>VdA(5}^bW(Smbq*0*#VI&-R2ow(&gXJE(-X5F0ck(XR{HQa0Q$al<=NznhV`SSg zMF!UgCqn++y9#iQLfwukCgZzhXq_W5%mR~pFLD#n`lH_7NJlNG7D&(+pv~6y_A9rg zJl`xjQb6egUv)yF1a2I-b#rffLi^{B$esfDNB3x;9p@&S0o&8}AaOl=0o!b0>!WSS0yD;R&$066H3s{?rtLx7{=MXf3uK+X~gFl zpB#A@9I`D}ik+W-?#CWP$bCeM9$fw`dX=(q*Q$)1&|JT<|Tp0BOPZ30XN zDHfo@yl&Kr@gG>gLC&eHT*%lFZBc05DcMs$F@N%Hq+W@6;V^=NiGmp=y`@xapEt1V zsmLifV4Ij&o%Ptx%>W>TEbk5|t${f2|b z>tj2K_4dG#fK4zpE2{6r2xJAY5^zsMkB~DK*grlUmqRLb6pQSZQ4|PtQxPB`S@?~( zyJxVJLF4FJwdFrCo9WSU%fQR)UOkJqNAj+toh?{jz@y!8tOtG@OvSIqcm-E=^(uLQ zFnoE3lO0ZFUqgK{!9g|m1>{1wNf;H5bT9tF37=t6$o$bbO&{Pb+N=V!K#?K+;30!& zjQFoFLfmJV>0(HJ8~&vWNt0Ajkrv5$QS0lg5$h+2DH|&55xE=gN_@~p#4<0Ue9LrH zz(PP(o(^w3t9=)``VS!iH}auQ_92wug!ZAK;jq5DWJP~uree~2#^LeryGMylKj@4u z#B%e~C*k&T<}MPz1~72Cjro=r?*pQ5seX5giTg4w%1RXhr01rlrRk!kL3dw>9I6l@ zbn`tJ3u&;T8s33n7$r{sNr13UK(aw%|2p&;&l9Xqcq-%C3cer38bXJTw4#yYPd7at zdW{tp^lx?4r(70|F?lbrM7!6x{F0H_Quu#SpD3k+9YY7#D-a|_QiX}pC9-!c+}$M* zFe^P8v919J7685&;62AI%@}nToISpL`4a3gj*(F(USfFToR^#TsU8XP|L`?1EhVM# z@1NhiN}eeeA`eEF)WAe9gxyM}gwB>o({J9L zywJfxyRccljdbPZ<-{2e+oBVsZvYU+GL|-)vxb0eKL7pw;livQZAd81fQ6&Hnr;00 zyHd`f79XV={EY|J)hh<9M!(T<|JR#IRhvhtS9u9pHbiUORv9aaXVy~ zCVs5rRtr2LAFSr#(em@}#V!bW4)F{@ANKhmw(iK-b2bf0zUQ%oX8QIr zUyQ*7O{2yjZ)<>xfQR@&uQqN_O9;aZ3(9%dZBTdB!VVzr4Fr2};Qw%Nx$`^x9wiF| z1}=Chakf;={Kjp^UHpShzpbOAZ%qjYksj>zNjZ$S62eMg1-Vk#{b4Yp2VnYm^qx;G7I`J)BV2xFs*xNZRj;)nRl5w(7Vst{P4w8WBrEvFI!qf zOw#J_SO~87td93;^m|vjJRZXCOnM8hRKCm{I;Dbw0<@QV@wz)cI=Ixj8P4G1>^j%o z>lDAitd*$Kuv@V4^S8qZpP2l~^K4wnTwhlxPduu?_R4p)J9@gPS2RiI*thq+oF5vS z=-)CJe0JU)J%-{gC5+wFQpcCue{@!OoR8nII-72dqFb`~W;>(P!7MTH^}hn*tgZ$b zF(o`4w4W%J`_HqbRVqcWc2F%k9y9Y2O;uwUofbZ2EONXw@v4#QrH0x~fBOH{?F#=p z`YVt2g2bbQ(>{@LyW>-yM=@+wY(Zp$a?-{GMcVg}`g-~;1RjSgPD22S)FfLYC{kn3xom!6^Hf>DNNr;>;TZ+)4%OY?dK!^tz>Pdb_G&W^oWHDF1- zOwY)j&Zc*3IbiIFE$!D1UOF0n{TcC3k-2lcvqBH6qzkHft!pW3407CUSf$@>@n3O} zq+nlOx*#I%WX$>FZ=cSuQ33IUYS%g6Xj9>$qCnf>5NmCE265(Ool>b-W}|yyZTtR| z8WwNyZfej>$c?OgeCbBGBz-ii*qDQ0?#JO{GIA$gRx4A9^6aEdFrkXd5>=98v0Q5S ztMrVcv09btr+208m>o|7)%7FPlGQrz?{l(p&^!<{qqmLv!FIEen``5KUB-SU^XA2> z>}=LS-p|bm-2Yuvw)&mF@9t!Ibly|iy_@qGYgMv;#@PP%$n#^f%8nc{E7r!hR)+ z?$@Rbq8<^;IQL;+pPu9NChL_MKsd!8#%{pb_zX<{J zN%#t<0W!mcme}Isz{14(`i`UUPeamMZ0nZL1dP-JLUt!c90-`cefxH2b^->T&e-}9 zZxi8+#1f33bvvB|5w;qiTZ-5mQc;5lLEOJ^gi;WtG?ZKpk9aW1_EZAH6MF)&AFT`g z0~6fMIhOS4?JaOF!~7Uuv;N*;H$py1gjO_^}0|W5ZNJ?sg4!~S1%}6W{`vmOQY&&#(@L#vcJ8CZKb$8CZ25IQ8^BF}FR3 zPM(5f@?a~hoUmQ1E1I#R(DDctT`Sx2HJIMnyP^5o`(f@T|qW z5cXcJ9JwGT=*#mAuYUU6n*;M0uUET@|g_=5BD_$r5Hg5&JJHgB@9d( z-(b(;ple9bYs`5QWZMPQgA-t)Hs!~MAX*p=_KtRVS1Q6Yb*+59|L}=!v>GVW8X0D4bh*|Vm2y;o2C;^dCQfx&Ac;P-FpzHlf5r3DGdQirY za6>?}Bv%AAA=o778}5H7procFxqz8zpAO6nmWkomcY%dC2m!?V9AkQxM$E`fOidv} zCQ%V+K7v@=5K#m_+PeS72BN3H26^GTi{$E%JUrwzt*tEd)I8q-#!?9I-J}=GOZH)5 z8iXo@!xi#2(js7;VaWnngPg+R)T#05e`r2F20yJCzras0 zTGVM+>!d^KdH2o?^*K4M0XLzGn~rWFQ8qt+{v^qV1T`gfKl~j_gJ#^&bkcz=LL`S;lzBuyWxdOUC4G;!kQipuS#T96FWW;r)h^|R%^5%;~K z_9ikupIIPA*K{mfNHihIgF+k%7{Vb--_1b@`Vl~}p;vekWz!SbV*>_s!khsoFQp~) zm|#JGz=29?CmhgFN5b#L8M>BtK?4rNLk&1D$WiO1*#TfB5cUheUa()jbY>w#^Ch|- zG|%X?Fbg_zs!aj~dDsV^h+}SU{$q`l1^corP=u4Y_|wyKbiGdh@S#RQL6!B#0Q?)k z9?OrV`=8mqeUr!-$-`TMUic>_U?S~`gkOL(5Aq*LKY-_hnNC7*W7}LABAZbHkv0bn zGrDulf6J~!F$5F3m*}M4nEG+h)9_&L0Tb%EOIazZN+5WV=*+PKCNSwRNE646=^!{I zaO%~ixviXxsP1!pE~BflTDM5NqkmY4%ySTq2*g_4M3>V5e6>N4z3?4CdH6Z#n*8kUJI z9f|PL0zadQ(GLa=G9>uyyHG!RUHE6CfWzMR?BB}F%J7-*HIqdpB_JFpVDR zSeDBov`H8B^{|!thu{%ERF9{SsERf#ab>zMq8RjhBk_7TAlCc{;iAZ-0M3NhU$V zmMU(-ay4Yn7GmWLfFKS(b9(N>`cL1#+Y7=Rso7zBfWAR%=1w|7ONny^S&W#ooPnP= z3R#i{1@!rC$>H6E$Anv5gTU+`XZ8rFp3)y+ftw250dKYBR_dbFK$NW@oNe8;3-Z_8 zo*4nPKx3Bq+9Psd$$#$#`}D~|0YIampg_2S5T<=~{e!;~KROiNWc0Fi-yQwy?w~n6 zftCwmIwo`pNZ7eTcGTq77dt!m=mB1rOTSE}V6Y>x2$Q3yB87M8Y$as)7m4iw_R7q@WfyYvq=@L9!9#fU>Xk72r7m0uI-wZEa^jhja{zNplHG!Vk}h5ytW`pn zS2Lfhu~QM~r3K%nCax*VqM-_~5^4!6O1C-wvLP$zEW+w&UC&B>`P2kJJC+ZAv zHJGX786_)vb_sv{_;J2o!+$JNGpQl)SSDk6$FybVo0CIp6XuL`hN_kraMvD-ETpwv^iZ{VvmrX#OG_JhmIWJnDI}8J!>d|_4GrWgbb-g$qab@1Jd=&l3T3#tm{7s@r za5is~DC@EUqvBv~G4Fqa6r(pgbtu0+iJlBsU);~ZLt|eZ>-C-oBF3BFTZY@>)~Gn5 z*{IZWX-*qJtu^B$T5yp~m$Ab|o9jO=?Trj^IlJudtalx6VN5%;$&u+xKvB5-+M7-( zVLIC8S0~PBmCdf1`86%{v`W(y1d2W6|Hkv4D)kCu{U&G81;uE+jw$BZILm>lt?myw ztS=ZJiMQD#b2;0Xaa@mm;lv*X!FbWy&d0B(=9he|N*8ve`j>2KqRP==ukUdFR78~< zS=8-JnR<%%if3b;6rIs&n^oab2VHpvJGJ<@V4c~SAJVrU7-f516VmhDD48DRU#Z_- zENuGskPL%`;K-=Dn+~lH3&;1vF02zHMeiOIZ`s$=s$;LoERvtOC=z3OrElIm@=VJG z9x>KIj+`JLySq|74bNs*hXeQKm&&lGo=J~+;>-!nIgK<-=tg_R5ToxxT zQ&mK}mY#poWXsAK9hzJ5k5zT%?;pR!4I=bOdk@HnCyJVQTcsT4M7{npFx4)Zwp-Ny zm-LSET;Uv6O=0JN+ofDL#z>0p#jjg?4+mZ{n1orcMjhlMykoQQGaln-K6aF2`Qrd;nY zH7%5+2MKMt(#Zq(~a0Nx>me-C~r3Ne~Xigy1~wrQx~4O$fTp~o#J(5OZI2MYtQdj zZj3%G<|Y?!G%sb>WA3dyc=iDGztcAiUm#hT#9+hQm~f$Y-Tp(Ubf%`cI0P`@$Jtln z@tT~clS*O8S2>HfZhq_^$(i}@+BOnd23g!&>h4Jw`+f&u)@6e?j!)k|^x*K~{d_C0 z3i||HZrV^}f}H0J@f;e8XmXxmD+pQqb}!=lfx{Tlq9i2jnBR|iXaftK@oS02RO-F*vNPFUb2{cP-d5o|1K#QlqDEkuz1HJ9PaVDDrNJV`v zj^YY9EI8Cb7Awkpl`)Gy#x|(^KP~_>XtD07ATe{O+p~)CqSCHiaNKxw+2+4;GY5xE z$EdpuH6Zn*qvP2*%t=o()l-=K(1(<&#~apyF}w{lWiUxY=nW$PB*K12S)A4MR0|CO z>9J5bBkqSJ8^SI)8WJtECI10cJY%bNb5zm7{x*(KP=LE*rnT1QIS_H@QSo(NHDG?q zim;P_KgEE@3rH24I7X{RD5L?kPfqtwO)6!^Sf6i<=%`pYWM$O^vH6BZd`~jo##2($ zc=Kn93xx)tZjg~sLe~HJ!mquFuMTbzpsj)gJ6Ip~)RDHgt4kZF7#J=8g~_RK@@&3q6#5 z?8ORgBPy$^%Fq0C*>dhjQ}9}_Cf|M_gm8Egp{LrwEZ~GYv$C#%NRYsp{zQ-@(r(=f z>F*R$3ydRV6B0lI9yRe5J90z-BX=c0x|jpOw5fE+7i|u8Oj9Eo*G6#A4xwxVdjrkg z5iRvM>HUV(le59PWLiI5`qYBoXsP?5Rj38WPz)lcPHz z45k|-`7f|OYNm28&yRw@;EJZu7w+P?2{3kwY;$|NqzH+FFssVee_`mkQKJ*c0>pFD zH-RTYaiy&pwJd^J83@}24HR$*_^4L}JgohjrRc{J0iSn#0T$w;4zd-Y!;oI?A4QS& zjz(-M#0p_SCK;}7Ao0B*7O>#L6Mn&B{g;bgZ9Il~si4;rZq0S@g1$SvsX9n9kV!d? zaatJw4OA_dl4PU0T5)_(cND7yz$c#FRZ>@B!WI;yj6H{x>|oRHx3@3u*?gy+55`Zp z|E@~pA3{>zc~}0z;9oRe6mIboc&5Ggw|2HF)d_cKb*1fm^JU$up$C!+bMs9I?!@o_ zT^7WpXz9&1?6P<-MuZQTY=zrq6=z<-MA6okoAj*W;==8U2nA@K4R)=x4-!lRgbzVO z#I;~$0cJ94!=<7U!=6CWp|Ht8*7^=h1UL(6qTZ+XQVe0r>8g0s4)OH32-@DTKZGkr z8HpxrCM;O4uVwhiBgCfE*Kf({+)$+H4@lzJ<3me%6~*P~42hlyw+A0f?iYzOz}N|Z z(G7WLDqt_}7*3K-kz&&3PxgBWL}S9y0OnvEpu3Y0=IO=}RE9W6ybtCS-jr=8Eg@nt z{m0L5#;3jj13kYd#`~BOpNYA-!F6{GU+0^jje3(EiNr0Ugl|EE2J;(Xda7iNMIKDM zu$aGQM0st$C$Fx~56?aDWI&Q+#y&#u5sZyeF{(m?m)kR^yn8pYRg3{xqk#r@>o(a5 z@%^ziV$1k~gwRQB7`eEP#+#@6A(Yw@zU)l4tF$Xu&SJ`kwTLTbsU9hrnegP21_-65 zeDmcrM$JLow3c(e!x@U8j0gP!fED<2;pom1J`7U!FBzFja-ei}!;qzR$Ag=}4=?Aw zLWgk{n~`1pIaSi@Vn>0*`4_TxQ!&CpKc>K!SetAIwi&vJq>f_~c*~9J|Rg;qZ2y|4p8rZtlQH-^-A|J>k-?> zZCt!{#m{-RrR{lfsj_+bcT9L24dp$Hf|Zj(P1CP-T%P);Q=s2?%R~H6_|}&RRZC1KI;T!WSf6_XZ)N;{iPe5>Y!~I zhgB4Rj9%ka33jcwhWWa>@iHFMJ2HF@Q!`%Q!|{DY;KpNvTxCY{tPS-4is<%S%I;(` zp!7M&@pJc3X`qsDDA%)@Xr=88A8WVL`J|eL<@s=WuDM1*gJ%8ozn{HmhLJ>bc2~ zqS$&VpomtY7)9Rrr=D;Dy$;Cp*fu^N+9t%QtS#8nZuWd3I6+!;e9h|CKQ7_@Vrz@j z`90=TdP#BH+k0)@DLS%fy;fx6zNME1%BP(v5geg46j*z`y!*z_0X()X11FYXlN%%K&idu^xAcj-X8Or&HxeWB08 zzm7YL4}>bOdp32C7{24;PLQU`HE$l{F57x(P>ru_N#1I;;jjK}2~Mp+Ipt!dFh|;Z z22F2!=nWSfKGD(MGtpsd9FrG*$?(dr>5R^&$&(y3xqXZsSr_O9FAG;i)fHA5$DQx5 zWOLbIdSmBCyB+f5hdB%v+XmHcDbyO|+n?{(aOMmhnfIdITg-;LV}$w;f)6 zyUTSy#p?T?pxJLTcEC7lER{Z|j=f)>xeTF#MIAlFRz@V+*~9XJSFZ!WVc^%q0S70@aHSVm%B0rD$Ps%H&`x+F zjB-8ML>EKGgKDSeBNQR1sr1L5Ps5x?=(rgR|9|e)!5z@$G=n`=A{RIR z867DABmwwmA)By8M!K0K#KnJr>pt`B2Otx*!Ga;wSbtV#9tBzc2kaI_7CES#okj4c zlLmV8d;QAoOFaB~zE%O3LuHiLkk#Tcj{PiFR|GXyVE38XQ01I6vvUnC z|8l#LzlL&=sG*RjK%G%-cM_x-0EI+Xi2DM#!2pFd_R)M0rU7&^Tdp2uGHok21G$3& z$<@FLA<~OO72Z^ai$Jj$sIkGXK0bHw_>N&{kpQdWCx1PlaKlV^u3=r1?7ILq$w2$j zgIb_BC!fA2d`9@(8+LY4&yRsz&xL9&zp!v-J^(d7;CvLZpwRROana}iZ>Si7RA}+A zJ2eO7ANmX7?DyMYNwe_Q*xS83 zT3pKD@nKQO(J-3aB@=W_`Ji9mQ6dk#naDD%&ViT;pe=XnPXcTK z1#k;N!7s;n}}gsb@n(R(G*ry1zOne7+SKl7CxGef+`)smnF&jbe5K~rQ!iu#$nLkvgib0 zlL!whs;he*9n}hr7)K^GG!tl2`}=J{FS@zsY89zT3RT2tCD*FUm{RXLQ6{s6c-PVt;jblb8zktTMw-gPl6B|H`mVv+Lpw| zdu#PC#>Y?uPV9?V;Q)gGtro0Tfj3uu@&nrjrBT&r+qnPWu1X7VSHK%VOe>QkQ_1OM z2_d?{G{d^p-eU1tOo{v_M?Zw-5{wv9n+y5nVOxA|1P)YF~Fpr%lYo z-aR?_4%?LHt(V~0dO^Byx?xq2xNT8{EK>(33E81OjnGAc-44#O(nOIW9cbz`;G~vzc4s~;v1r`i?(jd6HgqhES)CqB zHMm$}bVmB1z`()IP%Lh|M?e$G0UsXT#rRt0Bg7x40BcMHKLyj;C0W^QAc(xcnBsh^ zRDeATcGYAORH-7soR*gMZPGdoHDtO*@`Lq8@x7fDa)C#$ya}3}TB})H3?>1rkD(;p zvTD`7^>sY;2Ilz`X^d-wsrF&ydiO4^fz`J^4S3Wb08*0DZ`QkW{R&N7DuslD=may1 zi5OQ`G%o@@!8_{Jri7S(N%9&A?KrM;h4vEf>RM{)F) zUG!mF3Ks@MLNjDzTt-Vw&PYmmX3}M588U>Y2$vbXP%sqzEVH%ETJVgdT9=weMLqX) z@tY@44lA)|i<`KOW+mL$)j84g)s9=_hixz2o=t)kan2#(aVZPg8@he?3aT48?yglk z=L=A(_9+$jzbM$JoAvtm2Aqq=1buPqH$EeC@$q+Qf#}Dp+fg6 ze{{_7w@>V8$&g#8m?D`H^bd}F?x{FRQR&bc$CS;V-)XEne{;X%c*#loqs%!~GjVNl zJ@+LCKj)mf$tXcM6web#yTkGm(!+SQLN{YL`|GD!1>9>EW0^-KD6(p!U55$p+;qWLsK>eza$Q={Gi*vFG$~&jP@xqYg)1neW*Z1 z@AzWA-9$~oq`xpMS30wB%K)pzZaIgU)4$xd=I(8b<5yW}%h&j1$f`3>B_s8jJxP|= zrYGdhK&%Pf@U?9>uAQ$vYQ*$4GOm~ZO83(*;XCdrQAX+CPZ~*iUEI7*eZtWs4Vcn6 zbO@z8T~TF&fp-RK z!j6$X(I*u|tpCH@?q?%4%MbiT&=X9q+t6(PnaWG{wfv$hrlIj(vnU*AwT9vWLFNsR z0ivM>#cbb)kr%!~+a_1qLi7x4>$Kb!m6F5i0{%vQ!bfkPz9)|)l<`fVy3Zn{A znh{1Jicd0&B!Q_I_rr1`!fF5_K2neptq=xoNwApsjcysa$e8!IU<^kj@v!r~eXGSG zpi{9GS%jbdOlHklQF`!P*hH0;s5JQ4BYD&FJ8}Zd8wd#zc970E%ZV2nQHDX0Zl{0}s=a3T@$;&YdmDRYS&Nj{$#aFpGH)5%eL0Rf9~u zaGR?78*Iwx!Um+}PTiHoNzt8V?f)<0*LKXjMs4&t{+pc%NN z?l5&SgFxz@bs;-F4G9s3poY{hxSTB8#%hMC@t9a6``X72>DIp6_Ito0EGvkwS=sF>=^sy7T-%s zOlGD?jK;yeAzM8DBdrJb67$|SKfrJs1HMy$Vid;|s&>kHivQN_r)u9cMp66q zc%kI5)7hr|+0V9qs`>>Yjhg6N`Tz z;EsSnLFz&Myza{urtCCLi-IxM`|UQ{f9_I$V5%w4VFO~KU;E?($GTfEGJKChMzPHXG9KY%FIsM9~0H`7ER-y~EH?Aua|1DPbX`S|z%b<||C{>00G z2wFVoaZD1=K`=wIb~P4N{iikd_zp@xcHP$fV%+~o`D2hS7PnM~rp4t-!8BxuL3#7- zIA;^8YSd2bt+v1fEp2SvLvs%WG)?>*J5}TB-w?WRAOEkC(%8}}`1SbZgXOLLH6@Ve zv_7>=6GsUlH?8@PtsAM5*lSL|oIbp*Hfv=30N1yhn%m+4%sF?n)`}Fn4vzT<>($gB{+;mTn9IJWh%qQ8Z z{br{Lx2l7%djWOjd*KkfyRK(>yk>@uDpDP5(SH6YQ_*dG!v$fzT$?;UFCOVbQ!jkx z|2(7czMWJMBqkBax+!X2{fhBTXA^@KudtoIwr}iOb~2<~78h8Pph?%BT(z>#qIdno zLw}wsRV{Jwu}rjRisDF=Pj$lXX676Pd^Avj)-V~L@io|Xk%6!US61Mf-I$e!Z#lu}` zK^v*rgJe=?PqJ^~8)6X`bB|NmAR|fn)~2O9FU-MFPIr!xZB^_7!}n>bXqH~vm^b@e zm(vZ5X_T5d*%XE{wCs&r<9Idpm?iM2C`u@86g&(ET{oVG) zF|!?v;5Dj={l1&tFOlQE&8qEDUg74Z(DO_+mcrehtixT3JmKw$+$T1o|^8$l(dq+3Z9P)g~N5EPUY>23t1K@gP=1tg@q z8)*a-kOt{H*ZID`amRSa<&Se5p2H{hUVE)MpZSDUPFRYN%n`wZ)Aq*Jr%~Wtfs8uBj$8JbBXiJx+Qh2h9^b5+T}IyyK&Y&=?Z%(Z$_oL zKQMDp-Z!B9F!Msu^|xKPfZg35!+|IzEi}%f@6D|n0H8+JC$sk7TK+wH{|+A2hnm$- zqkEuOC1|Ld-;9Wg>c6Pm|B^;d16c10BQJGpa=)G){H~ceJosIgaTd6|x_aEMJX=z> z^7-#?z54z=Pc$_H*WZ=hFIEYFDvw<}15XX{?5|qyv7`zKZ=85|Ug}?^qgmq3o;+GN zVgyb?+z?SKTyqU94%G*E{~sv?tW8C^HX@ zr6hZ5?W8okbIC^|0RZYfAmVDmpY`5dW8f7iCkaJ@B=mG%;?x_c$@p*u-gxJ8yOefmmLWGf7$d-KFwL`L%iO*8Ehz{zHD!#b;9JaaJe3hCk(g^<$k%iWo?0GZ1K> z32I|d4FOTfsMZ?znr*-~gaQV_F8Ip*+kO3ZtKxPvf#kQA?#!`R4HjlF6@}IsMREf3 z1iq;WK;kGn1IApyO9WR-u&e@C1UdcnXVzKFO5Y~`4s_q(9tj51fc=At*v_G;DUady zk$W^x%Kkpa#b46-v8`!nI)@Yk5Mmq)mwA--3XDz7VT_?1&+P>q$BjHY_m$G9;9<({ zEzj@ZMvYu%U;=a8F4lP9^Rl*j>Ey#kqwqkPi+={Y`*p+G>n`B~+6NBFf2lWW!IQwp z=F_R>^UCtn-ySgH0dIsENGUF=eg`L_g`M4*g!~%|df5y2?HnR`Cokan+-|l!xktei zPWzM-!|bw*zn9Fpvd1|m)$CrEGJDJRqaz}MtCagQD=q*0l(VIh!qk@VEw!E$>R>K) z@*bA($Mv9HNyweLK+!Q4m7qi;sF<=+$G0;0L-QdP0S)zuN?X3Hw`zwqgY^4(oPJ@e z0VOlPGYL69gL7h5#8X$i9aT0QlQ7=*d7aC;`C>l#m#iWwj-68PVGeg7Yo7V8H&GbZ z82=?f`tQPivwX^!Hl41+gx;;?SR5tPisBCr{crt#++*{&pb}b&0Kn=x>m2&+n4jkB-_a%9;oqoy&yAj zzS4?g>($6Zc;nMPci>lvam^y;WsJ3TpUpF|w+xs;Z(_I@&4XshdqrOEu;0dg+DKsX zm~nfBHrn1w5a&zK>{;uM+rReLd0qi;?lWR$H%8>|MoF!>5akKfDeipw@`v$Nh%NS# z6OJUFto6f;Dl!}muNdoO=Fo@4Z#HOSXO^BHD*AhIOl#54C(PCLGB_!+6mvee+vWv><@TVmUO-4urc&Irnuu_y8zF%fS1tUE*l`xhkdRr%-r9~Z!;&(f9i zS(IRV?z#YJyn`DqrkDpdM&t6Ku(^g+6vLt1TiIm4ezkhdPxP!c#wwgf)a)n*c0f>1 z!6>gtP#VU)$1@itW~NWqrjb;C`f&KG#j|InmrQ`QwzpVxw3rJrB5pzBboGxtWw)4J z$RwA@hL_a7GL*O8==j&&POwe`JGWw#$@IWMlc87!le??42?|6^?y1f|%kCMESlIRcj-W5Ci zx{GqlYducD&x;3Kn^zkf{OX0AT#zZsfv%xDkwAkQE2oJOQ7- z1H6Sjp~qCya40P-Nt_9XhKBB?_eFh$PJ547=z(YQjrpHZI=WE1=k@4oXn^M~V3x(e zm;n1lla&(ZGXwn*F!AD(I2!8#N`=!kew7k|8nEKGY{Rw=T_?mgcn04P_IAtY^I+V7 zu;H{1ADUPbAX`Ql{`z(EN{Y4Xmwo!At>~gD0WInJ4;XcX4PcE@48(ANM7!WH0w5Cw z_yN~_bbJS>h)V$XeP@0e7VB>9u5I@?&oS_x%6I7Ls=^%|9Mm<3PCZv>#;V)_N02C4 zG};eh5n_4yPLX-lyocBD=Jo5+1N{6d|AC!g{$dCRt)9ohFQ_P>Z$PGXM_)+JS_QsT zR3v~*0Ek6_pFW{v6IcrM-SjEouz{TFIHd|5aX$wwvPm4^g23#~POl6BwNL~E=J*51 z35_FM4E7*tZi-Sdb40E*05cKVEQp@pbGeM?+a3P`xT&rWJ(Ba>mq-CA1uJJPsWb=| z2eN)R>d5w{etrrmd-G#E_U0N*-1lkYaIxk95+>x~fs|%Y>VW?QL?j>$;42tS-=frj zV3{}hD{y~IR^6OX@Zd+cy55L-0FT&>_^0Qfu|0+9fDJKKnC*O04U@t z|5CCln@8=mIOi{DobDUF(Jy=2HYX3OGowRf1x*-^4@YYMN{Q>b-sS$~c~f3>AZQ@Z z&OEV1>z5HsRNx`}_{nPB{iQw|7_zaLLg;KLILB4J80Yjcrs$20>c)Wm4kE; zS-(J0lm6*bJ1bRH>XHYrJ;1T(HM|wd#zNjIz;&OM(=^jHr0~<JIS7>i>>B;_}O2-`h4^>j6u?a6Av|OLiw*z$*pj3BO7vb;K_1ERIhTg4jv%&aP!IpX6`z$(AW zkn5X3k!%uszc;k}ib=g1YRk(Em<+#TD)xl^V{`4!S0v|1>VIypTW-WIxIjr(J*X^6 ztjtZ9cg_BzH;EwMPC9;RbQR9KZ9+^7?;X*fS^D$Rcdz2usrLR$RwY#>WRT6l{(?p7-rr%KTB>8ubz_hx2vQZlk-1B8*q8C!~^=a zN{B191-s(N%u;cK4xVJlGkGY>{a%s`r);e0<>scpN&g39O890Y^Z7CjHmAYPkL%+3 z@)X}4_5ayAFqsEI_gGqUnYomsZ@_|D)vNX#PXqSbAIkKCLISD*^nMra3{Yq3&q~_} zD)M6m&}!K@$OXLV@A(qKJ(srIxZQbckbZ5ExqtB+@g0oFN!;;MT2`IOymt9>KQJRd z33Ye;!~ul+^Rf-LborY;i-;E$MrJ4egc%ACcRq&A@D_YFQ8bx+A7cI;^M*<9rk2m% zv@cIT;l%rTj}#X%iThEC&1*``YA=i|uzz$tA3vj&w;57n*v9)vL*-H@d6#rU)qzCM z>QI&bDXF>)3pAHll!GFKncBa+XU;Jx)d}dZ+IU_2Ajzopp?~uoci@#F)?bzT$!U16 z?CjGQf-P={B&W(LB@L2vyL&n}b3cr@!EgYhkv zj;r?j`}CD+e1T*9=Y@5dcWQ_deB7e_<(7lBNXW zMQJl0%O#D@=C@zV{7jSBmuDD%-`eWx!H|0`E!IB!1r}XdYd#APRz-W=?+V)`9?J}C zR)P**iv9c-XKxurj1Pt52^|Fr9o}O;*7tdO?b~82`&f3V&T8TCz_I~vAf0xCat3SI zEYHN^4u=qS8$B-b_=^X4>!}f|OgP#3l^<{uK{o+}PQ-kH$N+FbfC#|<&1OmKhWks8 z?;z$9*zEx1SN4+s2i9y2%!9yOC4t-UtxdceY-Mns0X~N$-N3M%ajLSZHwGhi(D(u{ z+X&O2mr`csam_E?-K%Z7tH5jq1YAlwI*HXv$&djGV;>kV0@ob|T}b~-fei-=8fwAJ z2U#}&RTD(;rRQs4OGeHJ8W&Zq380K8WV{ODiNKhTNJt2zk3Yqlyd+*-@-a^$ui>ky z_(fI@QUYvCD=TSD&1~g_dQgBMze|9*VAjV22qR+bKbXq|_cdUERy`ZVV6MXabA{c6 zW}$K=H>GD*lUxo)eNfLJlBhqeOTZ1?^3jB8t@0jZU-b zeIJnM;#5L8jhlxDC5v0Xc^BYj`x%ao6lLv*h>F@|{sEpZGQ$N6EzCNwS9%4LIH-sS zXhdSVySoS41CsASxd;{>9T2SsJ=7A^6w8shhl}s+96LGVsWw0)SrFcgLBJYl)#_oM`BcZ*-DX92<)*|4z@woc(}tte2xjo+`%5sZ2Ng}V zO8jGY`>h2fS&h;zaY5Ukil9w`g#%efJm%fhFhfV?K)D44UysB{1Jf*ZE~$%A6HTnu zj8=fb>1q(?Ju(lFq}=}I4q-)OS@TU%SU?&pe}>$Hkkkv@XPKaxzw}P!bKY}EC970s5iDhKRF2k90t6+AYuWTQ49>Iae57j2{$FOP4b1f zuH)7vUA;0kSWWm8!y``6%T(N>h#`m(m8D>zje&G&P=CJzQ`zmwb&q#IttC|%Kpb9h zjaFxu1qKRWIEX&|AA5%$^lo4hh`=St+(1-E_`!Mmv?ucKzDb_KaMhgAdK0MXz`+sP z4iF5%L$C$bQlslbKsQHE5sc44mGuh@>B>Hze0lyS`KRHUe%@`JHDGNZ)*}pW9U~}Vn4u=D{7ekkaQelCQVoc$ul2Hw;D^!~U z&!q^SZ1M3cCZHQ`9APEJMLt_Tl5!@h%HSzGg6G(zq zj2djn%-*M(aF?*^>I>F4E-5L2q%+t^3Sk+gxcuR$`$Au3fgwvQd6S#PRQ9p#R#0|j z*n>Lv3Zoh!GVCI&wfd3_LP_5kPQ9#7Ck8VdW~&$V6J&~}82$8)1GYDEHZw(B)D=bu zq$lQH6*y!p1ZH(PdSfR(X%UcNH4VeWQcM~Qc&04LYbGZh8L}=fm_N%a-$`Goz|xmL zTO(a`ixDU70^y(zO@rm37L3M^Oqf<|upeJyo=?ap7X8w<*FtRdNCAIk$N!2-n z8&A`2TZ@$}^pv@ayjm1eN?wxhBv)6Xz#3_ZG!Jf#GIg@VNU+Kb8_psV&5yL`)7X;x zuKlgo2s`o?88N0FMv4YrN}Im%OKG;}$2&u~Oa0;+mldnWRJO?3m}N5V<~X`jw_J`& z&cA`f+7=;{o$gOK=#;!Mh4nWwyUo3M(pF?jkWVUFo}F1motQa*W-x$8-8D)Fn<5Xp zZ9a*eC2DNnTlNlFvX6G|%pdIfm27A47>-1*4un#!IK+Ow3gL5m?l6q*K3DvtC3tJ`a7Re-+c5*2gUZ?j!L%{G#YEjS5TkwlIB+{B7hxy0i)5@cmYPL+&#vEZ!$TGxxOVH?a#zTgKxuKaOj^*dvHbu(%}&mnp2CVDv{2W?a-9G^6>TDBDfa$Amqb zj9W67?kd6dzW9-gzY_I_F6Z~wu}>6U;4!N?M)5WUE)e&L2pBzKrI(W=mTHTPaA$mL za$E8)4=x^qQe7{>L(h?n58mnmIu4q9}t;q4*m^X@JIO zxS4*0>YVyv+rPic8IJA*uuzprm6CtUD+}VTwVr-ca>{*?cY^H@r}&3j(9bF?HRvt!?BO>T{ugAfy$JHG;lQW8-!e;bU$PYOyqClT?;W@ zkUjuT1|#5Q1HxOVv7pZTv9j9_f_qsF$cuFZ+8j)T!Icc`c>(c<4GGblpxK7v0=Qtn z4|jk`qXuaFArRo8%g$qZ72b;9o8L3<@`pV4A?Tjp7WD)EF`40R(d; z01HRpER_mG|G>J9kB0}OF8IZWloVUcG`L71v|@=+aXi%g1$_UjSjTK2rhwG!+jdjR zFf;*zz#N$Vn8DX6pPdJRC)(^G$P2Ly|830P%kb9wFU%3b*c{>7M8?FJ&^?-%n1BNf znc<^=U!aWu?2m+@i24tH%JL&Gkzb^~^CZk^fF79D!f@!u1VO|QN`3B8SmH*|)$Y1#+QJ%RI&4%4orSR zTINyNgNP0lY&ipM01wt%7gWo`_ybOOF1>FNAo0&!5rH#Wj?TCqANO6ykT-WV2vBV5 zFqKti5K#_06Ce@z2JpTBLu6$&c4NG((1SxduDe?mQagxoMhyrda*|7@Gyo`|POGC* zckkjMNiCd8tYvZO`%Mv&0L!7sHk7johGbxX3xr`P=lJYrB&!@gyK1JuUjdo6yzc;n z0+2G72D2a#KO<2Q5Jh0xHGRA&_#G0-!DPk-Yb2PiHTO284w=b-*43M-ptf=?3D6K#f8za zuU87uQ0vEn?)tJpGUEDh8#Usf78Ih@n!twQ6;!2d-QA$)Nr9G516*8vaLJHLAGzQn z-Si9?F(8o^qVNKh5_LP^@&{=>fSVl|zybxnBaBt6ikXRV^afoP?1er;<^zOjpw9_& z-z|t^EdkqI*znM{d2RY1HZU&!qb0+;I2B9pKfkXzlwZ8NaLf(EUl~iUY`RrRTe$l;A;3npN)?%8XOY-vmyak<+ZG2nCziczKrqy)|?NZl15}C*sepa35LOKMEt6W^?gJ`1)B&2Dz~Y5IOxc%Z{vP8O0TG( z)YL?EC$g%BMzr{|Cy?-^pq~a3DzYH~P=Xr_>QTE&(13Jt;QWXXQ|SJJA)I2o6!+Uk zNyLSJp4;!InW;i~;Q>SHLdbu!F}TD^0wxcE16qSRj=CD;$aDvrkg5?pc9FRmBHaK} z4LN@)MS&e|)0z{6DZ(GhddOf$fc<;iqZUL{o%zLu!_jcEAq3$RP+u5zqj>=wDWIO6 zaPYf2hiSx8MFYZFI2NF%MxnZ}m9I^Fh2b}hIFTk2P3U1&)a_Ub+t2(Ln*&z}H8IE) z4OU{hB^6MnS>UL`DkP&h)94S%!T;>gEOPJ4l`C`G2dm;K;_j6=UFM=fqK|VBE~8<5THd;o^Ak=?nsEcd zfeP`UnJWP_nz%NB(GC^SSw(nv;{s>})E$_ClUhwCGSU<0VJc5+q-x{H9^d#><*st*;x>KYvlog!mCoz zX%5oOlY69r*n~`Y(ot_R6=k<@EpA}{*ySN*yV;4~q)74YqKIO`s;n~KkyQKTHjl!0 zxe2{2342aHy?YeQ_&$Qon-4`clpP}ObOka!X(sM`&1{bMF;xDl;wMkvQcXb>TauEj zOx@)*Y>sXHEM@1Uc9W|Y7O2lN+lNQCQ9|8Z8H`a>ef@?f5y?13%aB^L=!jfDJLQKu z@8OT{Yv0>womQ{XFLjH;-)9-+u6xlBq1H2_N z>8iQLmCipTOhysBZE{;!7lpcXod{xDJ-91!OpsRM`~!#fcgB}bdF}-%61$3t zXDFq;@VL_S_hyCxfERY4fD$eOFY~eIAv;9O37Lv3)Z1AKss^(G&lqW^xkpieQe%QXDz7Z2Tdt$jaB^J#01= zon#z!yU!|i<}hBV+&NYCI`QoZ72*M=U!=aKHOC+gKnZV_mKIMQS7+#G(c0zCk%;Auw6UwX|92x?!;Rg(BfPn~F&Mm+`=eG|)l8b;PK#Nh`joiLK zBMT&)GOs--wE*Cy14!+{H3Qwox?P=q}L4sE>JOW?_G^+3omx4I+3;BfPG)%yEsoZY|CUXZrE_p~ke_oW_R0@d|u(5Zb z!y1w>AV(|<#9XWX?tHrc#|40Q-(4~HTm&vCCVb$=`R9O?{96>PRRYi}F&iu$pJ6&6 zGkEyO;C%$e%_~6kR(bABKL_3_a_xe1iVvK+{yRiMC_#EH8O038ER=RY(%JwF-5zx4 z!MCY8ySjcs`MBsheMAcm!pOcD)$|ZN#~Aq(>S82c3&>CeM+J~bu>u5xJd;73l#0w| z;V8X%lNJ5~NWR5@-oml#`Qy!3FwB8E3IsB(ara@TE~yrm5228Pc_BF8qs%UM6JO|` z{tLU9axG9X34^!_c*8ri=CCA=|2o&ojJ%|}cI`D(Y6w?@Njj1<>pgF}3|l|cO@LqN2fn0KQu@oN-N$gfbBZeh88Wfr=Sg1W-wUQ|6~n zU-YZuo}B#AzP-GyIkIxmX;d=$qL4I?2zXF_A~$NF8i3C% z;uyiY0}Y_{^gAF_!^QXx4+W};5PAhD8GLV`J|K5E$i+jWa3GTd5Dpa+`0=7aF-$@; zK`wO;&99*oqgNt@s4g(&T@kz9`2_%dBd8wW7ITAj4&=80UWO=KJ&6&(pMt&h0g5V0N2@L|A?oYnvg2BKeO{=o3CU(b^0+WzBhD9^$`;90X*R=@+7V#3oD z7W474f<(r*Z~wj@{czE)2J0xaBES{HwrcG*d<_`-gUj3KW%_WsW27RTWjEMIfo1}}c^ zkjp_^5AK0a7g{6tO2Fcx9sJddw7Fqu76zsWpA^b>E2g4%g@a9A!3YM~OrwP~4a7Qt zup_YJGg3yj1uyS|^Fd-KR`H6`iE;E6? zZ>-P|6Ba4TY=91s50oAt9YZ4w$cclkaJb-!8t4pI7$A8$ySzLOoHQbU@>v)6&kcMo zyV*w)p?`ob2&L#jSAzBpsQU-w@FpP29g$*AQbd3X3vC#{JH5lhX3;TlgnR&c4Rnn_ zlMn(!V;BzrnHaH^}&!O z^8kZAw3Nfrey`8leLRqB%o1nl^ zu*}SH-^*Mg11#!)sv9(bcycs8@(sc(psOnGi8=>HT;`x46uZvkx*rHY0;KK%Fr==d z17c3x>4R=(cz_UQ4&B+`+A|O6C?HuDmTFbv0-U~!phAH43H>%QR)mRGNrAEE>&LDu zaK?i82Dm#HX5bnb!#h0r{n@G919eM)0wFCm3M={Yg$ph;^nCF0C^t4ge-Y?CUKxB+Mn1aAbFLh6g;K0PFpG1RB=O(k)(7A0iwPi=Y)Lyw z#=*fz2_&M%%;WMrpVyGKw#)5xi7`hxUM}K$28XlpIZ{RO^SJ3PFV77UkxVdaFO=5h zym?yQ#C`NILe1;jRr07eNv3bRs#U&_NVQ*4(bGQ1WJJnpL7x-M?tA~q*k#N{{_{L} zD|MMN8A^hDUwW>3^GUwISInKueeU4B`d3s{r+08f(6ox+m@{+PP_b$}v|mkftL-*% zAW_i?6-@^d)<|>LfiuHNj`#9q)h-2^Msnf@F0^}Jo|EgbZd_N4z*#Om5e}ip`zGHm z#Q9~KBtUTa#j%*Oh4GhL2Mli_7LxrqtD9{ZUGu z$)<4dyT(IlqzOU;TmtO^+be(*k53aBf1F8PZ|-~2*xR+FXYF6drg2Os*I6csF-LD2 zU{>z>&LL8G+Jz{M-VkHvZqI8rA^OTEETv+q?flVA;cn@Pmrbk56!nP5$y?OD|K+RH z3rI^-vyW{nS_|Ovp9KXdJ$zyPHr?y9EbFe$)wdYH`SH4}QCj%iLEgY)9g_IXHtttEp7gU+vO*{ zX%N|mN7CISHcutKvv{3F<3!rvme2`AH`)_MN0z9N*7bH4-P?L8)?*kU{#U2RWxpCn z$6s|PBB)AQO8GL*dRS)4;GtHfQrd-=<5NVLFR7Tro#$5BoN0oT6Ut})P|$E@HilcV zj%B=!;hM>hS##y*w~x82Ze;Mpj)&OjsiI`PWLs>+HZy;%{DS21Ak}mpO`u5VKSP=k zdi$&(Ax@ULGJR56+2bh?coVpWl=>qTzq8*{ z@R}%wtLV`a&w$Us(k;&ed_dw#`FN4y_kU)R$`ITIC<`y7M8G(e7-s`27Le=TLL>HR zfYV`XO|zszxf4)--TskbNreK|&@0t~+EC5f$vNN9tLbd_$ypO@4NU;^JbT&`g#3@7 zWZp!Sj>NX)?eg{tPFRqU$ zKTPu7%Zo5gP-kz19WHo!{cG1L=<}XpVrsF{)d2@&mB5k{2CSyxrgq5>0Y8Rs@vrHY zhZMXlf=Pg$4~cTEH96?mLIs_Z5@c|(z`z)+K<{g54PT5v#{DM<#xG&%yvYnms_s|Y zUlM2A60hNm1p{ZTOK_D=C>Qyo|^u9JpV2#svdrK?e$cD|J&*HmX%#79U8zA zOp2Srq&rC%c|OBPdi3HF_<(GnuWR0OX1@M?e*R;A(^)lNiKy;}p?sH;8~an|9}xWT zxud*X%{<%pz(Dc!t+f+_@q4Zu5z^9cgvi;56>(U6Br*Jau5dR^Vn2I!uiRz2MwD-V zzr|{p^Klf*Sj3$vw}TBSMMv}QJjbOZ*Y2?wkDgb_q!UB6Ar!$Ua`Q>eom-#RUhW{+ z6nvV1msu!7m6+WI;`;}1h{9Ndw$zU$S_yd%Au~%9Xb9bGf-*K_cEN0D`+%L^8KW{i z{X7Gfn3&jWrY7}g4L?9Xj||rU>8}NwZh+_M_=&($8HM`^PQ79Rs^@V>?J+*Ay4JHv zgK+3N-tVU=eYg+IXH}8kSyI-0V2!fvZkD<#M*S-{xwhE=-=-nj{r;hL?cwF^#jbfD zGF6JwDa*t-Z8su;=CQM~5s#`QXj{NRz->i$dX6mh9~KxVL4e$O4i*?!puQ^jW@P{r z`)aRj$~dlHp9A58lm>aX#F-lmU}V9RtR@ep9Y2uC7?7Mq*(-`@66@pKEvm-Z4)Iq? zK@x9U`J45HzByO~B6|b+0w04iE4YJSDE3m|n?muqHBWaI4=Xj#JS~wXfnt@7RjMZcmp%op>{^rBCz=-tz&{6 zJ1ng8TRE%#i&)nxvB}%@{~guWjtmYiOb-E~eG{6D&TDjB9dqlRXQI%g+KI0cF5Ia= zCLI8Z5dF!4Ywr+V1mduzA;&l{OOLx%Yd0-U3xaGg`~K^uSlJ+l z;{~G}=){sdPY=)*56yZq+OV0U+elS zn!2w^uFCGgj%-H(9mVQ&ZK^Dz6h!j8@0!~`+Fn1cFW)|Jg8`j=@4vy_pMoCjIIT z2UQSEizgLLnq^xO9hzO1UO4m2`swQX8hfqGNvL>?3;HUq$HhJ%tl-)Yg zc-s5*%0h_lXRSPG94YAbV3eFB23eR!5WQ?*e70?i4uH})JHmILpJ;D#|nJC0Y_ z+Xuiq5q9cCFJk>q;!o2la=*8>nO$amTK?wY&1rA(F5jw~?>_&ytJv>KSRBeKUr%+) zfBI@+K=S1hA6fYI)Py+i`jb`y_WLuR9&UCydtGnk|D~Gyo`>`5mZi`bbA`!ogUhp* zTYY>Zi-~T_KgPsn))emD%iN8QsWAD4*=`kiPWwZUxu2g_{#ynm!32}NMQJi6h&$Jh zmc%Z76J(Fy`7uy{|IWkNU#tOHb()`M67tgJABNZlFD!Ctb#}7IXB=uUyl=&NWEp50)?;kd`ACxVq1w&Tn5ndMU-HgVz5L6W zV~vWs6>hvTw z#e5jmCo$4-JAGoWDKK@7a0P@&65{S)U$rWFJ7yj^W&YGj=TZBc%{Gr347lI_jKf(V z7NP5#r9y_Z1xDBB1!-dz?Cbmwui%qwwy_9kJN&v_E8Vpe{9J&2M=UXWyXyJQuj_yD zF7w@N<4=rtzY=FnKN5~bs=m2N?=@sQcgyVe@oVbz2)@-v6}`6d?*lpV^rCj^%~s-K z^cI;MYHp3t@tTK_+A3CC_3r$tNs`*(ZizFbc=7Tj46=p5mT>##3>Z``z;qay6C4oU zHMq?IEREOiej}0wIKRUtj{uTUAOuR8(tqUutP2WSc(kRV2Ug4(c->^hn{eluvz=}Y z7~um73esI_$4(wJ{0EjON)g4;3Q`tBb-h|H7D~V97!eI#FU6eHeZ@_-*Ru5;| zhnUw5Ue!uBe=a0s1DJL2ZI#QVx&VTQylkuLi%eH&lYbu}f zxabZcB?N|%=Hg1#)9c8c0eHwTZiTItGDh!#M$Grwrm#&LJMJ}Vit_p))q26c)2RE4 zGR|G`JWOYsQ+<6ufY-OR8M~&Ar4R)AUV7*2z>Ecw_O^||D-5l>N>Eh-tY|&Vg$%tI zBhvuxMp0riHPpBxTd(elZ48T0Z3z4dvY6G%%5@#8Rht9!9UfYl&4 z;OdLL{5K?6pUJ$9saW_dGm=|RvAnnK?_#)InR4ld2*2RKtBLfSoQn)tKrR4bGo*iX z!tCYu%8D`YM-E;;tDX1|b5QthFVUKl-ar4cpW(uiSG76EL-nz>2alRl==>kBiF#h; zX1A~o2F(r z&n|TNH_Hd$o()V+s(?Q&5F+TDcxl6zMjlm7?Zv|h?50hROx3#I z1lj13dzIQ+ncM*Xm799;>pAtkx-!+%tGmTtzrI6wD@-eZL5}VSwA5*U0YWj#G`{E* zAo6uW1z25x8Xwq2%QEsV1@Jb0udSKF!w0k$ z%21=PDZXe2+3^)*i3!SBa(ig-z5asmP3!)UT{o zPuNdK?IjBT{F>!n*gaKnNx!OiY&ES&nD~QhU+hVUUHPN$F^Y!FX@xnwkz=P$82G~K-+hn4;qE^} znc>kCiVGt|pJz(3Qyg(H>*7}g>?P9eO~|trC3mu9DA{Fy@U=Y3kuqg{%k^qFJ%vwC zV$A*F$2Y!uqI7}f*C;Of^Zv!*+2}MR&nB=`$mK4N(8KKMnsyax9#1Mwr0vq z`LV(NrJ?=fqYC3co)gRWM~N~!#e`q{4Skpz`Q3HFi$~xYjxlcvLVOfoa@LwevahmwjLS*dKqUJ0I> z?S$=$|NBlSp#!{V>2cgaJc&4ePGh_8&-_^Tsa^8&HYg6AUCEq3R*tP%*tM>O8W>U@ z88GpT^Q{huIJr*HIW!EliTk>IYz&qT!1cd=t-aTyc*3#NLqrcxZE}26iQm6!*fzRe z4;7=u);&EuHDt=%BojZ{ZTUeHt)t#5usS6s)9?UA0 z)o+Rd~A;`op3$=l(?7(m9IO0-0i}GTX(93 z$6fdbOLB;-en!^QUUjM(J9!d+lve2BG)-`ok)YN80cE~EUbO7hZi%B5rmo3NBTiR5 zpy-Fv*iw%pVR66g!K=5>jW_cyy>3I zHOEzvx((|v{qkORn-2bdF&oSuk*}7Csg_Q#%L-HlEsd!x!Qi`VxkIPkl5^P#@z@NK$7acUdFYqo2}zc{EH_ELsx zhT|>>El=ux%f@H$`*?Pt^-I?N2#3iCH-;vkn@~JSy#&p3qJ>;q3E%c28Q(OK{$1hz z{ReDm_^w|MEhNW;W<9QNP`6I6EAny^vARu2H*MAlQEgW@_uKgos9iAD3;TM~D0HMu z{wkXK6AfNLGCSP#zd4(w4Ok0zR$W5<)0^{)hTCpOQU4oirSAS&m@1fZyXZ)hX6G}L z^Lo+pJ-nqP;a-CaXLzAfez6+A$Qbu84N>|19~Zz`s(W5lOqNm-;#BU(HD-=|m%4HG z6=z1Yk+9jTI=8SvXtY+8+&AcR;GLD^X7NdakO~8*{)VC{&r<7_!*umm7{Be;1*nZR zjIJ2CGG7idHe@_`@hv1rY&!5#&1`(D8t`Ur&tGr9rBa&8TCX^7l;u2}K4nbyT7vrb za{&XE?eu(_Jk`vyfk#sZ^ECJ74NI$A(ud;%!`9Pl>1?0D3T}=}!ptmdP&BqwlC_+4C=<)rEhGx1o)1hnA_0_ouI)` z7;l?N6igghqhBjjJ5Zm#rl7{?_ljJ4p=^ zVa|MStN(k01_?2l?Whpk3qBLV0&_Dn8@18SoJ6Jn_x9cm5{b@yg24?E7uV%=o?ZO^ z{_%Id+;R?n`Ty%@BW0fciE|(`%Y&UZ_A6Sa`*?5flSq+;xqkoO`$yJCU{@E16O2jK zUJ7J%tKYvA)&Adw`+xs2>nY*=g_1Oom0#;$c)C`Bd|E`X#ZAe&j=o*~99{LsjW8xw zR>}mHd)R@EzWDmf9~Fd{zWB>Iux3ct^`w=Kgz+#mKN`oV7~RFh{pNE!o0lw-ld~jC>C#TL-^~}L z8uxn?O=|3Ja5YN4ukt+#=!m~oP+{&>C&q0Mt3jQ~9QCDh@t5O#GmD;rTgX-Mo(Sx7 zpVSp%3G3?UnfOe29Oj9x+Rl0?EbwJXB@$p`Q(1ZBhYzb=TOG(SDH9}6VDb%{ty783|w_x$3M-~>}o~KvKZ0}bG`1*3F zyzxwQ;>-F$-eaQpWfvxVEiMPv3j>Ql+DFH?QzP6NAT9MFh zL$y(zN6d=cO=7+xd2wk}uqyb?bzd=Uvx;PL#;Pd4y&FNAhjrmcO6L?)zE7MYDX)~` zjcfx>Hcl0hP*noUkWZucItXscWe#fw-?J#aMw6AdL!ybDZGu-YXP7Ad{(2sU=C@_R zqtVzgx!^u|$$yuwN5?+qWo)4Cp}szd_xA3aZy^IwHornUuYD8Z^}`NizBwOb`9ZmV z@&4@f2HAd_7bRjP$uSAiR}ZCrDJv@*^0o;C>S*5(Z&Gk9PmnjhrNQKa`-6_K_I@F+?B$}9Stx60nJu;?!1@5pN9=PKC-pV_|nw%+x2%JjBLuwjqe3VZb~ zA>L=uNZJZVhUx`HvWe*){eiOb;j($Qt4Nv74UZwU!H|9v{vx)jt8_Kq1kR*N%FQ1K zuje`sX?&7X)Zs`h=~}(LN6+=Vu~}9pTPXnRp4h#Bk0f`3%q45w`L7-IEr^=?xE9by z{<}=akcY)^p-Z0ce!J&4k}U1LqPnkdE@`_3q)|ySg6p7f^Y%&4rrYX;7xIFgBLBCo z{NKL`dbk-1>C>=rtVwP+3HHT|zuU*K>?xhE7h7;qdbc{8G{^T?^Lr3gbz&iB$Xn^u zzMiQ;TG~Gz@&1|N>IH*pkJJO1W&C`nE+lDf-Rj)!Vaf0%k!CTIA0wTvKS{1Nb^mku zuAF6!UiUvkd#7bCRd$gSmulM7KRGSS*({UwHa4|S-M22c#}>O$l99=nR4~8b;SAQk zzn0IKhOOAaZ4*4G>=JAO*T-->1z{P+b^M^|0Ctzu_u(gD9CC&fm3vgew_61lBR} zI6Rc*Ue2`UF>Vo$YsO-#+SPsK2UpFV9R!_U{A}uId&fa`ar_PwA6DtP5rury=k`3m zbc79>?&)M(c9m}3-MMVLe#p|#PjpYi*8i33lNn~!8?PlE+#!sqV0@(L5@K+>`L?}z zi#|Ll7byq0xDEri0{g>@xC&#eoCPW~22yOy%^leGU!KNw<#v$=W;9rdzV{eYE9md= zeCM7ld^$!Ef2VI7t9I@68`9^6JVh$y4$n86lJwkoU#5uMEhWBkrCR=$9qaU5cUaYE z?6~h4_t7Ip%i_durT^}0`Kx7W6{IWPU^HtUjs4jaU9Gm6vrP7H^Vj8$a<85Gx$Rd^ zmfZ*cl=+krKTi12vAkGOV_s2pq-_$y^_OXx!}KZfbKe7tgWZR}hn|Jn{W{G*`psK; zkrJPPfLv}iagI~C@}Gs(&!JbEbUb|Wj@iE2nFX5`)0wiN7E?Hkk)^F#H*eOwc_5vm z{(0nnMr;IsG^Vl9RcGf(74z()!j+f^T(duXkMtd=f-!>G`|kE$7Grx-nV@Z4uqoNI zu9RS@ry5Xi7>Q}Ubx}IUL9g#QJNwUZNU@8Wkl$5w72Lq5>-ueG@Z0lp!CX_+l>xU0 z3BTUUspA?wO>-=i8vcJzp#S}8wwK`ZLF9sJllV|o*!oM4R8kUoxh76A@0i!5N^hB$ z{JIySHB+SM&vL{R`ov8??u;1P|?=j?jYxnM?4^~$z!2_`e6~$Xmgjrwl z8#E~lhGekg3AhME? zOE?^3H-EW5>EiUaF+YwYniT!?R=W>!eeb%yA*0L zagcFsbD3(kwRLWH%}<57I>dcolI(V6kovkhcewE~N!#Tne9AlZGLG2^EoE%25%P3K zAMbV$*l2Mo9E2|h52YH1*%661`LobdD!Ws$jgSEOx{}dBfs3DAl8c zj)NJ#2Vn~KT|JhcKezY{l)g)=Lky&p*XNG>g`|E*cCBEgcIL@3+v`6s53cMg@2e_2 z;IHZ5Becp*+{j(~7E;pq)yMi%zG1oD6@%Fi9jlF*%dobT`&T%IgAQJ=?9rAwylnB+qP}1V!M-cY;cYpu>GuAVS(sfg=fEVkpyiJ%jiOGWn%}A;Ubt+a zJOmJYY*4exQ1PH6D1XCL0-6QEbQ&WdkxF)I%VcUx#H#mK;x)&y){FGV_k#0Mjd_vY zD2kIZ#D$Gry53#e#?SbLB|-}noGx{Y-ba>W>qAO(>cDW|Bi$993 z$S@X$h5>qh{jLKC7o}-3az|x9QLxog`5i1IZ>5d&bO{sBV7g{rX!Dj20|YmXjFM{9 z!7Pa?4g!~#4~ZYp$tbCX65HdB){n-or---1Y#ml6Qf37lxxd1L`wLVl_m?(dlu;{t z63Eh#)0Z;-L{JT(vtDF^sXIE7Trx4@>0?PHqEGfU5;t_iOkm=O?~eZmq_BP98wp8c z`lvb!ruKX&M5w;O{YPhr2E;S>=(0DIS|FqEU6gu$_KxqUK~}H7V)xRt(_BzN0_p@k z+71M(<D~0 zfdQr-_u4b_L6Or%Lcih3ceUrkombLxlQI)3JYh&a z(`Tea@x7M6JepjhBqEPVc8j}Q4<1^RF?u7Kb#6ktk}~!oqvJigY3;rQ^Ip*_I9`)s zn=p!sG-hzS4oLpU>jY6A(6O7{;@#t71LSec@vHMMf<-pzBN4Rq4}Y?%XDx2(m|GH} z6TCyPRiq|cNDMLvIFk{A_(EhZFN3Q@*zKOhldN?n#kiC&4bS7hEPOEeIl0-fsrJ+8 z7x>CC+`6i?BP{UkYlF$3$K91|!@DF`$z{$=jWq8di5B+K2NJpn%@7T{18zRU1c$K6 z?U&i>kQQ*q(mN4#bG6~!Z&wfRjmZ9ScIt4M&9q@95~1UqyWk9gt{L;Z8S?@`CwTWr zKrB&t@ZS7An1Jmi=011>VRv7R_k1!&pNx!mV$6h^Ps+d8x9ArN+sPX&Fsue=O<7MWG+k7t+;Z` zHI46%vrQR0no70G#08VQPtI%Xij+=Y^^~=Wj%STa0Q#^}i&K89#>B#rJl&dUT@^_*NF8j(}blbyiV0K6U!O)UD4fD zAdgOg^|fKta+BZ2uNMdj^$bkTSRAT1xLipOuf}O`ums;Yq`0&Y|K9rK27|>45q6X0 zK?@$WYumY*pLxvK82wJdCS`vJ-Ya$Z5)!Gc% zB9)1)>hSz~9VJIluRE_x?j^yrtA}Bem%TgCB|sGoAuxzjqu|j+Lb{!wMwb0|z*o;r z;OEPAgs)FokU#m-tpOIB{EHy=cIkimJw4RVP#*+xa&uuHexVY@Vup=ykRHfZs?_t< zXRl~jA3&b6A&&+f%G_UvdS>rNYej!#7y*UGzmU&Ttqjr@P)e0GP!&BF>FHi5PAR

!8q+ME4G7#nAkxRX)dn_isSSX=$ zu~Zzz-hMCF?A*?dEL7db>vMV-DY_1mUV{b)#16IrjQu;$%YKL6dn|)G8Qsha1_ojK z`EJYvXhi7(7%o8;BgS?BjXlnmae2NK*14To!f;bm_1<-Dw5k zenG^<@>lTMpp9tv5Ioj&pxiXrh=B@7y!rL@PGH)no8|p#Wrpxf^Ml7gG=Fz?1^`>P zKKiFTKZ5WZ8?MhS0LnjFTc1rXf$0Eqeu#k$z@KXS4IEgQJVn|qFtdRgfY*#3Nf}8v z1GOEF19J_a?8z`s70Bjcq_+t+4$-gI&wcW;`hI_Z+`ro& z*D$Kj`+d&qyk5_BIIIEFCR@Gt4h~+tbLWKM{dIxUj23uI%kW-v(;p}AU_NddbYF9` zv!I^J+|Z)IU@;8*#@uVmXGMTj;LiW zWOw9a8#!xHLez+VNI2owzPk7937*;N%R=3w|lqf zUIk)fG=eRhnwADvcnCe8z_rmQ9J=K5zM7UxaRq;&q993=;@wVYS8cw`WJEoB;)J=7 za$cMnSpUVT>Tp}2+U8^4^TXF|-o{SiRa+E1n#JO7Xm9|t;02r*Jl$A1T3zmBMe~R; zW6Yl%>HaoqLoh$wZI_mRTTN08CY*msb)cXKuOa`EN4kr1gcg?LP0TeZ&dXN&QPL0l z3#*Cg1ZJ;!XEvct5E@Wp3*T2w$T%`yy((*}N-ED;thpJ9n&7}umJisvb!$lFI-sRw zQ!Q0Lp#6n(dJWyf7F+>%MTsD<>0-!~#|e5t!``4_<@OJ^4@M=mo)gP@{G$f(F&Iz; zB@L|o6m*NWjQ*9;1p*ahMGS#>AqB8p&KyFap2&!`aSLbEUaFxtnBh)%%DwuvlqeMU|9m{@59-%pqy7lP5!O+|YZL4Myn4&Y(YfjO$Nm zX!tu&y=#>3JhcCQ1a+-K7v&Cj=p4Zg5OP&~j2SL)jA!18X0!A1@`_7Jlzl@ND}m^s zcYkwj;ULf$0eC4Y&grxJ?dkp9q@)y?3fX6W5dMo6Ai934k}Q3H>h46 zPiap(r#$fCc|%p;eFq@tFt|P2c%iiKFqw<7&QDDTF=C;^Di2h*@V-9|r4L;#%wQ=~ zf69tW=md~0Fn*)oy+#S9lk;q`1HNM4Lb>&aNzd|CS%4sLD|82s=lq2WgGE^bP-oYQ zJ{1lr7*M?E=g*ci#Y4kPuk&~vdi^OPGnsyJs14No*YVz}9;F1HU+CQb1s7{;PT{gUSZnEdEMdQwaU zN=;9X0l0@HdrtX}j*abGtxG*cTDR!pOyR#3JQ=N*^PjZ-5K-*h3gZ|6z9ALg>0g8!|$z6uf+;Y4djRg`%cx*lUI&o?N$GNivmuc zW^A3yqB(460Y+ZYHNb3GWxikvW;xa>K<!)G+bWPY8u3kTL`MQoUWj3X+2P5}r+d z&rVO=^fa4YfKXou3UY;ImojsB`T)GiMdx_rJ*{LQ@79pZ$*nT ziP}6r)Hs?mxz~o=BlH8JJ);6K0Y(FeERjYgHbbbPALP=-v+a+}zY=o$wl_owXN4QU zW4(W1V8-2TIdH$j;CTVa?)7Nz8x-^s?pIOo3);l?;^(`!{bLGcDRugs>K@HlWO-8M z8|t3;e-V0FshoXGR5a*1DBBr@ngJ0+Md{N?-!6)tli?Hf{FssigczcS#L{Sr)aaSH z8YdIm7Og}!*g7g{LH+^ocQA#`WEqLRN=hn(Zz1H~DEw9^Kl`t&d31AB(ef_OT+zRc zc?L{{Iz{M4PzU4N@dT(82OesoP7%f`m-?74kh=hYcDZQYq1agaA-#7?G2|%~ie+Y4 z-0pSSpQ}ges2wYBn;y(=TZc`y=g`^a44$0SVk66bH}3EVR+W-g$XID|OJ(fXll0rA zp+STBLgLIiaG)KiOxl11$v1!D_SnDrhK9;x$Nn6dF$RlO(mdz4N0dh{dFYpHu7+CR z2_Vom4m2Lt;8~9Tv^F#An=+|oH7}=n9y^9w@(}F_N43K4)JL%Z%%CO;H!;b!NYe~GBLU@KA}9hE?7pvXzn4s;)cB$a`cj@ zNtyL8gUnMX#9*$c=U{NHi{;H%^G^&o2+0R|jyCXUYL~alUpaH1;>lByR0uO$Ea#EC z!8Hguy+GXYMBKM6J(jKT-a=hlRq(+awBW$==?PFaxdd6ov)l(w*{;BRA+ybO@1xmL zBc-xADnjDc8yOjSEb}e9Ppij51X<>^(y_OX)vJa!KJHEw6}qmk#l>!muJ`uX73TZz zVb!BeSg(uPCIuQE8dLE(e$q*~$zrevy$+%E(@6uUIz$dtSc)5+D?K^&%Oyk@hCv$3!*$*`nyUp;rRiEA#*V|0Gsp=+YWRkv z^~-J>$l?X<%Rpt#RSd4=@mm5glGvqHX8L-bllz94~-s`ma z)G0tqoIIl7LW+qBH$-E%f|61Pw`AtLE@xpb(6Ir(g5PospwDUxOoC*dgX~*qKG~df z#npsS9H(%)P|hI~r6eU4&>?~a)_hIwNbzFc{13YY2hHHU1b!1yI5iZ%2jf@@dI`QZ z9E#$Z%w0QoUILW5yc903my8VCX*1Y}P|j`ts77HHAaht)m{78yC-tWICbj+@|0i55 zFsvBtV_t-JK)_-4Aygx0U2ho~t(`l6Fa7tGwsmhLYu;< zjfew#T;Q}_N>KRNv*q~QXq-lW$>-FWn|784I&y8iHtGuR%+ednZf=!TQ1#PaHt8X397G1^oyt*jU?zi5 zfJ4xRG4G%8g+yKhQ#{g*%XypWoR!{t-*gXcn?=R|jO<}HGfWH8X)lDA_;nQ|6Ee?? z7rPf?P`q@O2BN3Ffrj%94W)r4``lS8tACEU1bh`tOw9HILV;&Q;~;uY7WfoUUzp6tca4whWG>Lm6LVw=8Y$yo4(CX!Q7^s+KCGLO-0 zfBW3Tx%+`+7f7$`N4JJf>*O%X#8v*HzlraIu8y@8$#A#emntJidIA&`BK*rRmz7$P zQ9A>%5=)0G`B)qz2eDF*phIk54B| zSl=A_(~H91=+Z;5nLu?}A@ewOg`4b+CB^k-?;MgX1mFg`{UOO!Q?wdP31Vx#=rU-~ zAbzlnmoL-azdsEe`RU#c=!1Na4w{G40EvaVy?XZ+y~U7aSsMNECt$&Pe%|I0PMvN5 zG_t<)$A4=b0dWRM0cN*x%K5{)*Zs_Td&&!BcrjlJ0-|WuY{yzXy)1)b%gvjY>#ozP zwnY#z6p3D!ufQr;_++3zU!CK?GW&Ac-{1S(xns$ukKsR<042xAmu8+>x^}Ihp5D=l zs~QNqVISs3q1H2VfKdYgfs}GOURa}xCiBy{%wyx2%!q3`W#9=?Ki%Lwx15`9d%V2SA(lV ztD^#>kD+4j`yLU8+NVLF9q^@rK_3p+Mm9_FzYi zOS&<#Qm(Y0_HWI9xiqhT_?|9}|I-EG@9MX{mqOcN??bNzzsSV?y2B4-gG_aPwn;K+M5p5|3urHa}=VVB7E{mHaGrpZkw_#>TRuxuCYOcW^rCX|&m(4_{p|zp1g2 zD5?dPmfx1zE!Id##I~({ajb6o>v{jH1>hg0q-ZM+GqbT-41iaW|8OcjNSANl{?6CY znFtmkwh7b}z>2b>;z5iNAY$W{(>dcfjVVF8yVHRQrEKKjkpT}X=~(#!QNvc$2Wx9M zV^S{u_U$>1nTSqV>S)ai6ijsc)IszkpoCq6~~MX?j}@#s~h< z=U+a{`o>pAT@>)#8qksAQJMf z3NQ`f5tze~8;f>)xE=C)koTh}AX>urBGh^T25A7+D-nq!m_8KlD|^sbrIvF7DL$I8 z4$R|g*R2!kK7ik6rPbGE4*~^2Ob87n2oNx&u%_kNBzdL=tGk>Tx2_1TVDYx~nl?QU zFirg8`}PIi_IC589t1oVVxY3;zMTT$Vi+|!gnZ7PJqo#I+k*n={IBwQmWmd~4l;*c zo})cWPG9w7f5EmRDG<6bcw+-_V{v40@Lywi5Y$WT>V*(CX^zhPjXy2EcfCt=W1)B= z*C!rBYlIhZfQGKFuJ{>uwfwV3pOg$xDb>BZ3Z^E0HdzoJRs|ZCKDMYrbc*7Xqb@Rw zflwSCz7Wg^lAZ2Ph^e8uXdjg3P-2HGZ(`z45{J6be>sZ;whZ~6EKFN5GxwxKg(6c%L}DdbmvOunUL zj~@N1w{Q>$M=Kp>0AE$OusPTpiG>rsDVEV)%Xi1noSH#-BFI6pZ13xsb*wjF%I%mK zr)%wFcxzLpn7jysW(KZx5B6eSh`##vX8=-3*vP}pq2@QOw`~)L4xVRWzXU#k<4M+V z%G{vy^RKS&-A3ddz(X>EZJBQC&K6JuFyL?t9+YK}KLt{SQUP7D)6T9K!{0S#4$*j_ zHT+%pMCVQ&7+W*=2PTg83|HbD1Y?ry+Er=XxKo@|e&^4R)-e6zP~5%5HeMlHW%_DR zD3(Sgh#I9w)^;WEVC~Pm{#25}fMFRr@edygK?+DJdm!Y-jf(jCy1nT!2?>K4uVJL; z!7eiFmU#?y0h#PVL5}v&dii-L)Gy>8>M>5kWSlqx@ut}uuCBi9M-4~3cG(D2k-&*X zV0W)3v)E|0fKBeRV-Z@NeVu8p2FijGC4Aa>t(Q&-KnG%u(bxA>n6O!GATKG_dy+H^ z(H)~Vv$WQ=^{Jep;u?ip@gNK^l~J$-!h;0DgOc^esazq?Wb;W_Y8=9Me;Jci-TpecWrrJ|H{!WqSQe7`phhsJ&9FK4u4=jzCGxj~b6&bk& zN`-@!xpo?M8V+Ht@=24=zcl_=b4Kk~h+To)(4pf8dcd#IyIsV(!3S%AG=E7kSj;XF zAQ(&f>AQCU|Nc1&&pz%F);;If^G-hZRTUSHj$HgABBf)r)%gx>r{k~pPpr@x@A&$- z>u|^33T>e>?^L(Mtg=>jyj<7X#(ieju3k+Y6{6rmx2<11CwBFY<%&wXt1&!fWM)2j z`n2p+M$d|wGxLE4t0#=Lc)Pdjr{iv^ocQg{1GcSrnd{o0h%lhMnb_eHey4PsHe-e# z^B?Pdgqy{zkmzV18d6&Z_{+;i>685cM6*}jzAawHcyv24t72gSM2Mkwi6A1t-_WN{ zn>L{e`4jwWanWDx2SRN!wZOSIrCwwi=?>kWDmtaxrtLWoNtS$#MR0MGYx9haR#pp$ z$D$Ca-QV+lzYuQo-dgM8l2BGm-&NE}G$6xs_CrJnK;gd`xtS4DmDWJLPCRvi=s zAfP5HjA8T~!u81%Atc$`%-642;h`OJiAXnU>jZH949Ny)$8FoUD=R1*5HT}MoPrmD zyly0&XiMqTngt3yf``pNA+*t)36#uwzvAC#v^%A6s*U~|VRIOn3(5vGi*Q6D=ZXDz zy8h$4yDP4|mKiuupTi4XxA7~>!j!`z`Kfvq0cdx+%A11g!`iLs+(NWUQO=;c(E#B8Yf6E1g|70kJlfUwP?Q z~q(#Qx4%n&VKz`vB_&f&7R%6QxSO3BMH+pEQ8=}0j4Ahv6KRn7XxjK zX64l43d3hYPq&H;GT1txp#X3_F}jbg)#jL-ld>{sHd0L*)$Drua3a(>Q2 z`K8KBpu}M4(nIyoFmC=X98e{F^xO@HmFVTK`{g%jn56V84gz%bvKgO)RrMeu# z9--65KS|n_2huJ>$D0?3_+-8t86;8$dJLYXWoH*79Z7p7T3=odHWCrlFnG4}Su}_J zIo$XQhNhLfHH3$=6~KLx>Em{dL!0NGibf6r#(*3D$bVE?J|dUUFC!IL{c4|$y1E`e zbo(7Y3|=I>jq>}AT*fifGP7+$;|U_Q!6Y;9)2HWOH?CV37|qa@+K2({9rfMxieaRO@^|v#l|UvLaG&Xmht&TmOfZL(3hY zS|l9NQMolQCS}R=>5~dmw~7cBDJdx-03%btACZ!fS@LczEG;j?I-xr`j>x{6c)neg zZ&buFyMO;IVFh8WTz&NDA}!``D!dUj0>d^<*%mY&Gs(uHJpBQAUmfV(W{e|+_G%U9 zJP+v#g=BnEv#>!GmzJvZwg8jXDo>k4}JoSm8n(|05=u${`^6A^L z!sKYm(KG8EJu50KC>jq`c07D!5TXXZ1LBge8tLfJvp!0@e8J0LO9ivFhZKfO>b{3Atb3c ztBe;x4=hwV;cMc0u=^{oD9zpgk^kw_r(rW5{qGUo_OoK)f#&{mm$h%P$v<#6F0TIm zzNYE5PIv`TQFxsziX(p5XW*nQUw`~C#Qnq++otdKn-aU;p|CTS7qaP2UzBe9bu#+K z9(r6ws=}BrU%srx6T_zzlKs(1K5OvsiM%`7-pkmI@E)3N+EnA0_3&2o13$l|aZU*O zukd|NpYA8}q;TrO4U4MDDAWz>27oI_o%fd85^f>sZr*yUdG@1YawBEtUr1@*M~gCM zeS$zp+hug(3GxbAe>~9tA90U^cwsir@Y12G9zz)S%m%&#lQE^rU--ctoqM$nKy7>GR!{p z8tMfyAtHF-(tQ7nFx($vck27if}h}KymD-;(8dc@tL6poM_eYn2E=3_4@X=rTyOZJ z;Gbgl{9MgUKtsiBbB0F#(mTsWQGkd8R~{`wr$7dnc^A+8(+NUb=rwsw7a14sIMrKU z$_dt6)}`BYl_ejMBXj<|pM-aQl5DJSO-6InteI-LGuLI2Hy zqY@CLS#fVm-~Rn;sIh8{0Bk12%bTbOCZ2@O>+j#6xIF3SzwL+LE@dbN3qoM?Y@Uzg zDKZUn$kt}8&xF)Li;@Z!&B|s`@kK91MiOJ-i7cJymJ+0j(!^5X(8hhsmN8>ji!mM; z2A{~V-gQ7FZ*n%#)T6)_Qa%ENn4D8HGmqm*xeXQIy3Km3+)1iK&C@|PdPCZnP@&^2 zcI-f?cbK6tGq;;(gonJNBqJk-I+#uL+jmtJ6<r_K5(h92{1vJQ2&2G*wE(Q)JE@;vAwJ;3;O&6F@3%*@^a@qk)+T za@}0J6%UuGupGz73y;$Af%EPx9^jnx@7pafwktW6BOpmoht$p5_%{EqJaCXMW=@o& zR0MQIWAj$Ln2TadB<+f;B`7uV1s;pGG>N!as(kdyG0yP9?cz2k%qhQUi3M1khOw(( z3UnS0D(-elRkTo`$Uv3dw9ls6j=N5K^@3!}!KffTD+4Ja8IGk)=FR>~m7q8Mkfj7lx zA!;oqT$lc#_WA0+m;ZC(wcg4;->y;flFnklUq@q{i; zK;tOb!iKa^=?h+GFM?43>=o`>HW-vXFq*OPIh!ur%rNp5_FC*-dvEQwhwvn`Bxq5W z-4Z87Y&+9iXyYC|JxM_DJO9>1<*w1cX}^27DXBALBtc>`kl8V($)EH%e(ace-!g<& z4hVVpA$^h)kTucE(;|cb_U3IVAd%~xrD&h%plL4tu1svI(xO&kasvVKr$C@Cj__Rs z{Lo24;(15sxpO;WXp40`P1WZO)23DAT=$xJ*92wh+qdrY9N3Sxl1Gnn)gb@-?b{P5 zRu(ok*YDpSpe65(Y&1NwU%?Q(=@ddpM$u#=TK&3VP!4?_+yQdHqgdyuI2^hq!%ItA z`yTx^ZP(W-DWqO?xndH-13GT)ebcmlvM$4lu2(GOhoh%`)LgxQLBsT&*WoihX=V)k z(QCXPqb?Us6P|oVRDgsS$F{8Gpo<=+y3=ROSTNl$GMC7yQJc5$r=dSiS;P_9E#xY6 z_+qf}F-j0D$yRP&kS`s`I#sQLsVcA7~& z!}Rd)7Zt~7TdY7gTCVbgGB{MZjJLEA? zD@V_xf8Rzk9|4eWMdqS4d}+DpN#DKe%Xh?+04^j#rKl=MUdN8!-~QH~{#F+cKqFJa zc_DSb^1D5zw5bm42*JJ;IS?$_k_{WY(B^HAp(b@`n~--eCT1>J0HqeR=p;li0R#-$ z7afYEKpoO)G(IAfj=QvJ#h!rsv;8uZPnt9AEkITLBWnJY=<`zpl1&vr8_*w-Q_<9m6)p*8yR`m*1xfgMK{4uZnJ!|^_26_dTN(0AbtV-FGUd+hs>oRzskjf zhzWs(0k+uAlC5@Gr@5ljR#e9&tM3aW$MM@p9yO+kbYMf~adiI_ckH*B`*J}I%F4Fj zuM>(?oVCLLZabDR`B?70kYb!Gnm02XPg2sd1HJ)j{%&K>&2=Ng5M?RXU zufmy{fdC?49VGA}O5c@fG8TL1W9x-BS|ca|-8qoX&%P$tyAS-=D@+oz9tnr!P< zBo0cw7m48mMlYQ_PD~gOIOg6f`dH#Gb7^ZT$aGKPn=5}8*yrWTN@{9qXPM&RH`IRT z$37nt-Z3&|Sk?PK<1X3|-jA`{#P$&F)CEn~i~Mb74}Cqob|Q2Ws#?T(jv3Rbwbzfk z^wyMYxN4ye?c zY{!lrEbojOH$B^PhN>cc8AcL0xj91*G_N<`l?KP5@Noa|vRS*vNpMBw8n%=ELpQA6LkjN z+3j;#WBYZZBS9C~kB*6!=ssbj=1!i`yZ+T~%h*zfi7GTy=ik`)`}cn{oE%u z#H5PAc@`(`!+f^^UZ7g^$j>H@Eir`L_+3*|EGB8*hIoEl9x zDex>9rleGI7Ex5w)q3UTxe_d)Q_!F&b3!(!CA}8;f(=usN~O@GNm7r;@r7z;Ya&d zb?Ti04#Shw`b!2eI}{-Fku_qk76GCJ!ne2w#V|b`rYKJf8Knc zb34gu@}6AFNiiP{fp4}CRqQPo6k^5GGRG$wE?Xv~RpPTU6v_dovj5sYnAlxIL*vHH zo4#ea{HYk-SQIlPF7vX)lOb6RNfU7s({UaT-Cyey;Czy9k&GA8h%X>+1T;V{kWR2J z%aTsS|XA$m#;bMQctCugfauf8s+9!?sW*EeO!}-5lKxi_stunU_b9+GG(8M zr49G$PS*lC`@UkTmg13xYEQyI`El{M+`_LHvllZ z(ik4ZyeP;sDg)mFoU$iV6vJTr?n+;2_c>b+XgOb$e10!BR^LrGcJ)}d%A)n|pA8sT zw);iyj|&Uauj}%<3{>CgSG)(_E1vLl`)_@AZk;~;^*E*bz$)e)oxm&he z3s>5~WhoGC(gW2x9Y8Z_|8OyXO(0z`c^vfB6WW%_?AQ(gk3x?DOwC8t<>oufuzGm{ zvtZD$lOK?N-~Nr!y~M;IzAaFMxhvouzHD(<1ktZf2pNS-04%(hf4!jGldeRh#U#BKMP!dfumu_$n|ais|Cn#`XL45&Xoc!!Da^}rM$qe(LM0- zB)l`SIs2FtSRIT$nO5uo+ZN$;q_Mu!_AWP9#0!4@^H8VrFX}oBA!VYomPt6tGb=ma z3z@W2;cg{{&uyR~v}%Pcbz!!8mG$@UUx>PB)J9xF#xmQ8tyj$`j6Cd~Jvlj6NIZP# z;$>RJHxTI(*hduYKduO0sd%e!`KW-wy1%_rJ$|RlY<~>A^|9x?Hj%Xg)-i#Gt`tQ9 z`f8+;!#e7qJ!Bj(TDtTJtF6R-(8>9j|L)_gP}!7?cl!N%p=b{Ow5rdDqlqb8kJN}_ zE5dh((zM*z4xeeACBi0QNdwEwp1z{;uOrW!bSIOG(}e$j_WsqXzxnCo9siC353-f- zE0F`v!RWt6MXS^3Wfia5XeHEMzwOJ(O1b^>VUrgLvh(RXAtr9t_6@@<1VN*ytlSB{ z0^@|?7b3CLi0NZy4ljF8gUI+Wd*o7;vImhhJ~T!&>X26qg-?)sfxvEo`T0?hl$3tC zXWJHNL+{6SYLP^8c5fApuvSvhMsGyL-90T+hqG@*MJa!7DV ziVN_?Gt;PwsLK}fmHpWorJq1JX=rGEC0IkF_a!oUlZ;A-@s*jQI)*Y@fq;e6nYoF_ zP#Sx*r5s2XL6}?kg?_n8Wi?Dg(`_W{eBhpG*S6pTzw_ZbC@p+bA5cuNZRq>QnzCOB z&xkNtU}bg7f6xvKi^pkco%xoPDxHE_X;&$dnF_Aendgh_r=uE=pi zvzTeV0eVbuzI<#X8h^M#44j90g~Yu#S=vAmq?UU8h=R28&Y1psigtAs#tI>a{P1B! zY2Ssaz1guRAQN~St5>fs+24+(LcAk^Iz{XjF|!-LqPBFSc6@YoAHEuhl6h&CTE*cg zK)gZxG{UR#{=I6w0ss3&(2O%Ud@@>HFigzo*AX6WBB}Vl;>+cs1*YKuK+*HVWV=ej zRxzl?Z`C9Mk>Z0=KEr;l@*&;{*?@8nfyTu(cdE#ioSbIBjh2CT`0MmH8!S2gF_;s7 zkAEd-rXTqaSE!!B$S7jYm{f56jO_mD^XD?WJ~B5W^LuH-|MoRd2Ezt zIuSOWf0{(wn!Y2TaM#x|INu*4c-$`Z#)n^9n8O7M;zG=Qeu8bm;=f>OE1z zcQT~`;+ix-vVOO2BIG_mvYQ$~`IKC7AsY6G_+d>?%%V-XS&b(sDJ`s%M>vD zp`?1N{=Tl+;>n<-2Ac%4sDqXba$hilP)ut&pC#}{aDL#g!82Db3{(Rl8K+k~`_BdT zGrUr@X~CB-r>A8| z*5GW#Y-9k9%r9KwONi44hYbnE~&UC z)Htl{%Y2E|0!^7ul;UuYtIEIWdOu zi9toUbHjw50-rwMqz5i!t_@S91c}TN3_cK40*}Pn1J)osy;W28a<&OaBo>}AbaFsm zTQM}vG8m^-?tsQ9%F9FHdB*b=wC4o-Z_7|Oabeg*eDNZ@ z^4G6%FV$vfXq3bxPF7Ir@QH0JoKr%P2R>25YVZpPND~&E(FsJor}Cu-&UT(g&$F4s zgel@On0iZttsHj5- z9UU`9oD=PMq#S3XVlxTsTN;)q0)OQlzP7()7fE?W&o%dXbt<)nw~7_5|G2e};Rl7p zg1DDkPBOk)X7ZGbGM3C02HA?*ud@15?UT1NuKLQZBEU@n%8G7~P}MPs4y&eHL?+0Puk`gm(mt0~uU9&g5T@nykf5 zKJd_C5fP<`vftFJ%$_Zec6HSNym&`|!+6a9HM(M%1P5X3!cMa$g&LHRQKz3S-n(PG z%R(f}h&-kttl>CA44=`i6^)&cFyn3xMfHdR2Uu~0Uwf9%fAyR%k)#Sjni_+{u@j67 zxGIrEBC(BOwF`1g0bb)Ahien*f2=n~TXE$Llo}Cu8z&4ndi=N`Ti)dTLdc~xykj*> zU91#tX;{Di@BdiR1xbO#?F9cI3Q(}ma7<8i= zLKb!2eJY2X8mx^*J_oTd_^0Roif;sb=kjRK{#^r8f z0VnA3#EA_Gx_Ri`m1E$gL>jwL_R=tLFi<`|-FyZX$<>3xio1^}z3@Y|N*0MW1+zu+ z1+DQ}{3ibm5TX~GFhH77NL+IV&mc^otf{9N8I|`sk6}2ZyZsGbMD@i0RveqVp08{_ zT5{=sZl^&%B69IWWGP@b`K8Y~mErp`mrb~Ob)I8A0^n%cD^boFSYYlp!5>2hzJ0!X zZB#CZdgsbQdvqc|ny*Cm;rW+@*h-T8!`E=&O#swfVm=3K)^B_TaD;9u1G~V=cv}?& zv@~zf8L2~Ce|Lc|?OR-Z=2Bwf5O#uCDWa!DDJr%I=8cU!YLoi6blA+p{;xkn(oEGn z!}{ayvcBtS*xB@$4cnv8fp7YX=98H7&qI&e-FHxBuD~5VM|DhQrtp7?TsvXG zSOs>`0enSyq$|JA9Ri5xM!a}kg{BEKmt_1uq6Gh)<+HsNvzP;nm*{Bl7IE=yp9qsY zE88#s?fe>|=Gt^9Ng9xJwOTvEM8y74uJLu5vIs=1nu!lihl}J6tXvYzm8^hl|6?2# zLR6-7!MvJ9?*t{|#+@K*xdgC(wm}Y+PF$0Vzh}e#LF9BknKHtKU<ZX1mtx@jw}x|#nfu9)OrH@WpOVPA~@Rp>Iom$OJm#*4hORT(Wy?*~P; zM6@yKq|>kpcvo1!=&^>v+|#!MM03uOyy(kmtyfdH=Cz+m>ratAu}w7*n>Cf@2+a;o z1mmL7W5-6E4=MintMYv=n}sqdbnz6S>c7ZeIJtTc9{l3?_(yt7b<)OBYvN8Jz*@QQ zRP*4*5wtl1^$ec9AFX}~l0R~QE+V*Kc0=Q@$~?>i<6xS;yTZG6dLz-u0&yU>ioP;w zBmX@jz<@@AICx5;6d?DdFEBjPu{;g^r%L^Dls>5U-+N zx3BXM0_rq0#cKzOiz2rYv^raV0Xiq{j>kG)X~zyWByl=ucg9SGyLYnzKp&G`ax*99Mg zETJb*E_8qpYAE!-(!nI{h`r}|iA5cguKSnwd^vR!;T-t`C=XcyPV`__xIs(zTrFg2a$_{yk^pt&+-0G9ygk z5a{AW#5uRqK+F^iEgzX=;-)`}v{b0^$fl1!4K;KQ+p}66P+RcK(~vQOvhwY)M{nX$ z89!RW<-z{U2;#&NZ3*IeEZ~(})P9if05J5JVIqrIhcgh#L{_9PjA<$`VFF=W~>D?p9L8b%YOLB0SCeB!NIHFC2A;fviyi2v^|D;0Z z{O3N}_3H&%k42QQp_hP5M*#^@t}{QfMUbeAdoR$oeZ@e{xB(%-^_T&mSb=xZdt)}~ zAvz+f!}#KWq_voY=D6dYkTVWNw0%Nt}Cb^U`0aRCo2X64snpPGt1w` z8SDvrX9^Apm~PhIbw1vpb819zprUKPKb_Z ziRQC-7X&1*uWmUE3g;TVSnV+G<1Qk-I|_&Y|%bHS8l;E`Rc#v zKq7~f^ZW0GK3zK_a1?0w{rl4dk|9tDJ`Cj&83P-*M6Q*8#-2K)9io62mW2QaJFI*I zg@#oE_1TPOBH+g}gim-)gil2nlfkP^`Sv0~MePtg1WVvL%0q-E-a$bp5opqbiO>kN zb(E1|_Dyur;tdjM%NG_>Q9OZ!NHW@qj2Mvu%FKY%ViN*r;CzDm(Y6SxhrV5$SWx?7 z<6OVaF76?wUYMe)gqj{quD;W0M}=CIzOMPFjq4K}k<0*sUf9s~Zv`d7f#-%w?SyoZ z?)>>hrKLImPrSLg*K=?V3N_%Ex4p)Uc6D{7w>Ya|KMM81nZwE=7zw4&M&Kzy?cQR)4{uu`yp4!{uW|7KQY`d9t9 zki!|FR}q3aHU)<;sw|;)LEEM^wunS9s~xeQr<$LxKfk@AMpfME%5M- zZLKW=ZRD+?(;^>mBQw*RTW0i=zWHO18MQ7401#*ng`{MbR9S6AJ@H;C4CCT{6Z-iZ z@4i`Qgwj!EtVP-hUMa;WMGs725H`e%DvY4Y9bZBKUccyQa7UBz{sWhXB9amt7(2Q! z)y6+sv3xm7|IxiOL_q~1@SU8yw4k@V-k~Md>T>*?l}n7HhD*Kt4^)Qnp*1dKYWwO9 z?vWnPBKxY&O5}Ed+J>p{hCo2t1OXeGOAaxvxT=2qW$Xd(J|eO=aJro;VLjYZVBDG* z5|aGf-6;Df-Y0w=v)K2i@2HYD!ZTf%PKy5-kr4?sMVAadq^tLEr1O9|YUP9%^G#rgp@ zfbeTY_sBT9JCIABPJB35g*M^x z<;$ereBYC38T;vGcYtGxqOW`uS&)DJ6pC~RgxUm2;rD&PMt-RBj)U+Sr#V8Wa(<}I zP*K3nD^FgbPY|K}b-Rl{S=H|*EJkucizwUU@j;1{ngEv=<}u_=fu(#;-B&ToCA#D z9h*^ox-u|-Omvz`wE9No0*LxQZx)Q7=+f-KJ|^O}gf30oSB>35Z(z5Inb}oD+tBLy zojKzHf)nuev~k#viPIgdI+UCl?3nhqn|>gUcy7C!XJ&R;x{qM=C=@6VMh8bgyD~Z{ zGR251y*XnCp18lSMpelCMdq3LY$?gWqk@Q;AIRala{qer{Fo-ZzbLLNoj}AHh2IpB0J208 z&A~O@Z^ph>`bUWQcVtfjO1$FQ)_TWs?N{DBCCUwbkiK6j9h#|J;CctqU!0hZkWSolLRJH68lRdZC?+-n z2zt%4VP4O=S{&9s2!$kiC34$<4n&Gh?|IzgBbX>L2j#!F<}-`SlP|7m&kh*Z7&THi1r!( zSVw{K50Do29|*8pEyuYIiRUf$ADS)BdYB1w@Lve~NdUqyykeq^_*Be8XLZc|Rlu3X z-??j99Eofx?pu}io8>2L+AC`A8;x3k)5~`w*JeIOnq5&*ArKprF;U3$M9;%r1m<^^ z{YVn3bZ{74J_^&Ot^5`7a?2$A{iPg$*Kgl8TJHuE<1gpL`()yY{*(vrdz(>LK!=U1 zurY$lN>DkO$tk92P1!sqccFncZkN23OoTy)}(3k#64VILJ!0(4_{HFBW(2`sfV*kb<`+>iy}H+ z7>H>JUoXCQ-22d@d@IV@gM2qrtIu6geSaLdZdFNK(|VsveVS&`E~I?cTTu7ET7Z-k z<>H&8nOEPfdQL>s#fzurUnxYX>&(3jf=lF9ZTZ}KVKl;e8E+%2(`nu7<>!pownfOs zcEAIvI^TNXum6#vI{r!Oe2hV-?L@ad<6DREk`>s$Om4uzNP%NNg<}VpKqfIbQgc(p zAUUKlG+)3A9AgUxc0zYUfr1vAWqkQ9 zaQso;6i1BEd;S$ml!-yZ=dQ=o4$&~Gs!HJ30^DK3tuV>RgDuCq^pfYkkYEOp>M|a$ z<&cwUhuu>l00zQliXIeTWq|aNJh^$R$Ml1i2eQE*QF%(bJw*xcHTVN%fFXMG?O%Ij zWJeTDzX#dJJVjPro_u@<)|-nzwu@qXMI;NX0f#|%P+xk|HAQ$cpO2vKP=He1K?n#ey`MyUBR@41@fo3+ire2hYtGsJq$lXM+atV?=TcGQmIJTnmzbv3lP&g|KWnT< zr-kH;PKi$ulis8rXee3^Z*MKAaTt!56&$Ai!HT11`1$%W z-X;Tzpq`vwKY8lZHJyXpSIlQU(2jzcWIT!0yGq<&e}qP;Xy}IU>ncY<{|JLH)(#pU zqh-sK?O&!PY>!?Ec*cDv{vcUlYsP;vX=q2G=5!#nI(fX>r@kI7fG99O1V0EVI2UIK z*&s3l90y3GH}Na&gY5iL-jr{e@ro70tF|(C(wzzt2nP)QNKQFWAJX1Hi&nnUJ=9k) zyHuM8Dud>=Qy_Zs`Lr+hcRd+(-)S^c064N}o{J>*ACHlN z4v7V`8OFPUpR+Dkk>1Va9>QtFy75*!66fYUSGq?JetHo{yWjQ~^+@^SSO_965%^N) z75G{ugfisVYt1g%!QIrHESRzzIF2rxGRJ%+Gz7ZM8W(h`wS19nx%ZCCX z?CN9cHsW_<&Wl&2#EARFNWU~USdZ);&f9d;rkov`iXV1?Rw%2f$ug24K1JJA2EQs2 z7dRkp{ZXDdNf6F-Gn}^qpyv@GGUJ0(+tlh5|Dx+xFs(sx_!0#eC|k7eSj(TFPm7)F(GGBX4(Cc&C#R;oN7rw4I$d$y%oP@V zAc|qYfj1q0u_tX`CdMw7^6Ecv=spLBOBg)>?CIcaU7f5ZE~BK=w%8V+wwf>Mw)JO) z;lp(eb{yN9_x}A6s6e=Ac5iPaipMs!S!pWSBUP3Ciu46nae`ZJp~$)GVt8X@(a2UB=8 zq{DE+GMnq)D-4>QW1;@sJ`BbRS&W`nLf#y`DV+)9MU?;;Ug98|RZj`N6D1TJur#0@ zpsjm<=xTRZpga^3l_@_pElu3Kz!3oAjt>;vITJSp*RV29kBE*yC*>kU3+j_2NVh&)9v(0 z_z8RpC@8~9>;G5$(oo?3h1+QrJ9U2d`gL^r<2_54`h_1&$}a<}m@84)sd#RL2{LHs zyeH8UeES{Qqf`<8&Qe8Z%~`#kc+9di+b$3@C`Tv38t48j`a%3gRRgs8bm`-kd>|-j zo!1>b8d-)?kuJtm>Hvtx&~krM>;p*&R~#J_4`BvHaTIHPE-4OwQ8t36k&pKtq7L^x z$bWFHdvLh>E@o55Q$;(i{^PZ6%rH%h;GlS;$!0Iy+Ia?a@wOd1I&1WN%o~o2=+s-r zV{PJkub*w3s$M?%ar35o%idtSUo)5IpEC})d3BWCD0795itPt3m+75y!uNne@U!6! z6YuWsrrEgqAzvo1Utc+o4+jn$)jM$Hz=1iR+8vbElJ%7gyngfT&Rf>`e+-+Nir-IP z7JAdyyN`C$jH()UkMj4k%75+pnV9>@Dnje)?YNs4Tz>xeu{%$Ba1Tn}o)2mo$85?2 zlOI)$aPYbstiEuVtf%xp`Rh?$R~H@Gn4ObTtomiQo~obwlS0M3#kuyzVg57fn_V;+ zraR6I`u9G}PS_rtvpU1soLv^%4xoPx?~L2FAvTgz>` zO%ujs!VY<7mFLghpe$Qi9=++@wex2iOfT$mx-FZXd~e6p5LjM?2Tcxz!k`;XP%#9B3ogF+pXwjwtM@H)&k^6pc zhhK)d(Q^9(h1vIZ-1#`>K_ow17l-B3?K1Rz!$x%VoR_k2w&Tux!umptER1wq7ccg* zx-|3j%IsI4?fV3s-ppMWA|}@%>d{owp4PN6S}RtZX>y)6SXD+c``x?JBieS~%HC## zMfE=SLaOiH8xN+by_i@#WYO%}u^UDO*ybt>5Bz6TGvv#e;zbwgo(#;tFfW@thZeI7 zLx07M$y7R%>0=Q3IjhIiVE#{VRt~{QHZE4@${Z7hjA^?tF)t^_ZAV^e@varB<}s@d zC4W7jrePa(>sGfM_59X9jm{!pKd#HtrvT8f*`H%oiJCMqnNo6b1koZR`mFtEclC>5 znpyOQqcS>6gnFS**X|!NsQ%0=^PE@Y?R4eQrtvb8(yObIv+aW}UX)+9+`jjsp;D(3 zPRvv4c%s2AfEVh$==y8>#0{F$L)J;G={aQB8NbD^Lc+thYq?x{WczdHwCWnu;(4lm zm2tlYTL;Z{u;~;U`ye5~S?|g9VVV3;{{qU5?6Q5@_6Hw)^J(z$G>IWGr?h7W7{q@+ zUb7~{+*Ygcg549J&oB0r9q=1px^qY5%~787%(U_^*ne4N7gBQ2@kZmj6&GjNj*<5o zYWqWDhtHzx`TOAW@+VIFU^9Kw`13C>Zgsa@sFHV zj`(RK>zR5WI=9}?ZUn#IuH%`UI2x6f%hJ>A_THGX_r3j-36h7kBL|bl^Wok?b-5wC zZAxpS{1Tdtiq-ux2H#txJL=Yu#*&)Ge?JMW{9dSCYk6+|w86ioS`X>CvEYy6{rHIk zUODf2`rfp5-IWV=o?J@d;gO!+J*!>O#O)OiemNdMcoHo3%NpzW&02&17AgfTdT5aF zsPI~k!*+AC3N891$jiGPA8;fYW&go9?Q7ghY^kgN*RO_YMWyZj445*QPoi3^YCq%i z%j8WPZ#blSXfI!$RZ&5jtGf?M<^RX8 zKe?$j;+we3SRc~hxk$L_4z~I9({BirB}us z_|GHx|Nlq&&fYRY)fd{> zR2CNQ(vy(WQhs)5-X4P=b@go-i|}st7XmqXejfEXd;V^)`?<$+%rmbC7U}N(zLAmR z%%TUlBn1N^dOgsSA3E{cHsN^|wzl$`iU+Mr`e!^TcAd%ZV$TLCI=t=NlYiDDi5ng( zE>_+?<+y)aOs)#!Tu-m0v4%OyXEJ+vrN6e{{~>pd_^lzS!0}%SG7Mzo+|RW%%(qZ* zOnTcC{P}^>KbLMjURmv^Zyi*KTse7N`0wK?5Y++vYZ`1|+ricOCZKz?o;I@$8A z?n+;;si!w9x_`y3z7?602<;DUE*XwDqOhiM_P3(_6;wnFiW4dd&^Tn@J-Ds$->YOd zc;sR7UOlY_dZef?WNPZ7wS0Ns5TlBy>WA_ky1DkiLDRs5!*1Rz0znhkg^qY0ZPR?! z|HjSdBn%u)FD8oioCsLnIn6oMD%pceFs}e-9-y!A3Xt|O_B9A4BeES6LeB-R(tp>? z%{XeIZlD&}Y~+&r~5 z#2ESLpXC;qNdjWhdb<4XxO<)M@^|eE?CaFJ)a+WR%e_E)*K^!jBSJz7hi%w9Yi4V| z3fE7*mm2yO1xw3bESx*yz>i6t?%($!ABDc8Hj-b5in!|KWG4(Z+kiez?u@>$2HhE5bNi zsq3EPr-u6Bv!*mmyL%^BqOku6{q|i~?x_1cIz;8!%n?ft`pTHMMdzO%bJacP%)fmK$_kP=kFlV>-l{dOd-0AYR zSBKcubB>Q(IdG6&@X4Ulq4B$ZuZ(lbTrte(R_E1*X;V~8&YXhTCf@Lj7{8)sxI4SL3>-6_AR2BDlLOzr~MAVRe`K4jb)tH#xnRn3arQViT?{>HY z?5R1Cagl!>2QJ`BIU+BdA~VOvF@Ziq&YcNXFWKpc)zZRUcUPW1Tu@bxT(=AX4Q{UUT(}^cqyDGKW9!lf zL}&2?rE3Yxxe?zEkX|dkwH*5GU3ph zgxDSD7mQEzM=FscF3c)uYpo}EW8+kQSRdo+lQzz7s1@TwR73qO&OH@)c9m;elUL%M zoF~OLeKYUqDa$(ly|rO8Xm2wqVx{rr7YAKRZM&T*)Aeve1KGdM8!9!)coRHR`PS^$I%p~2xqItZnf^Dnz#{2u2ffV zs=#ad%#Zfb?x*!!a^V}WL5OECS>8EboCzcueI(tG{i%4{18Y!p=* z%Dw)pGtX31XS788MHH>oWA_-E?j7`pXW|wP$DT=`4d-7uez#EhJ8RsaXzDw)g-d+D zr2=VoF6iny3KVakvM$#d9hP_7#iU`IZXZ5SV5lw5$a?#`)*}O!J2ZM{%5+LD{_r>$ zUnlQ3ATazL!)+>LI2(#o8 zQj!h6l4vHmS0^m(TCLxXysx7oPBQCdSpup{ius0;X|+?i)VAei1g|UMZ+~W%V=3A; z(_p1jA~LPKUFh{5I=1v})twmvYWx=-mkV_pYZ}Fj7E5|vS^T5AS7lql-l0F2@)U0@ z(HX~-GQWGwr%D?zByL?@$LYf)HWnTbns`?&9pqpkT=f^ym^qe3-yd?Op`7E+|2 zGF+1uGATIXm8RP8fm`F-zHK{fHc}lk9@@D+?C^$+#A|Idl6&^e_1JVy^OCsvdLHoo*o>>dh_!b;-&W@%FOk{b`(RL{O!^`dRcIr`y z@yKn&@B#QeD|UvQ{CoyuW1ZX%$n(;O)6+#Tx1g07nqP`DLK8w#ebQGZ?e{STft z%EYfm=V6g+D*m#bFypu{3oC`ly%5)zAKBcZFd>ffy0^80`mGAu%5JC6QLVil9d(#p zD{=JGM8{u&+c_!*>~3Lae>R{_1Li;on&{&$KVu05P%{HkY6GSN|hc*wO`>s|rJj9v`PHd*i7wL|4z?j2X`_pkTO@%bxffZC{pdN~wJ#D;s zvqi}X6HzjrhT-wzc5^L-0|$`qq+9hz-ZQ3Qm@r*TllDJ&_88qtoYX9%5zWCp;~xEK z?#gVQvnL)m#ShR$ex{S&RdM#3(1-h*?`pErtxMFuX2q+Z`*|={A^6VaGx{jy?%{!$T?>bTSzND|U3te(`{Joc#6WsJ>%sPzB17RB zE$VGyl84tH&(qV0;P+9#)g?;#l$F6Dqb%r-L3QkavPfbW{bbNsE&oVz2z{iHF7F1m zr(#UEGplOX9GdL{_hcIVUZlIGBc_}2EImi1fM$Q5;Yp5r+f#*??0Hy4uh0IzDSG@A zmw<10RgROy)jlR&`Gp3Cr(Y;^wcCCth?kw?E0r{`6K?lWjZY5#9VO;--_(MiD&rG< z)=0od$?J1kC1Y3i^~q#W$6Iblkdq2!qts{$y)PpYthHdUi+$rap*K8-bR2!2cJa@v z%r5hXxd#U4F2v^3@H~Al=0P16#H*t-WWPS|l~A}Dr3gbwxisxlvrv;IgH5aoyzdMz zHOR>vy~!JD!J09&OlMX~A)tF&dn9nc>6*#5LksK1N3MUErhCk+{Z@I)4j)I(Qf`$q z^H;gVeU@G)s3r!hjnfiKa5 zT`<5dw%U{_PuoWJ-l8SV#>zA2135lDSp90a9K}_8;*T$z(lrGg)opo!Vq^Yt5l>_K zmyY}xIDh=XKmWvm&5vxlp3;n|Xf_>q|21c*$i&h7&YKoKi=~zK)_cRSc#zl+r}HjixS9A zomWI9r0v8U5YzycjW>$LBM9wPZa9+ihR1qMLm&sstMgTaF4%o_w~!n2-xuE>q-aE;_?sbb!sX^3JF9-(o^#je|)L>mMC9~drw7=QDk ztBpS8_I*fe|4yExH!3A(hC?p&pzQ*jEdTf8F;R{P!rXLzpx#_U(r)r#b9~|kCP!rY z4Yvv{%nk}n9$A5mMg7W}`Wgi^#EOu|afV#G&_h%;pHJGq@XL>Jd3}_bO;KN+BV(e8YWx|4_Bu8xH6adKx=3RYWvRoi7F>?Wn^RVOZ!Y9f zJ{4PMeM58fiBGs6qg98xkP>^ftXZrnCx6u`twX}5EO(~1&*kx-Q_FnbE%^5a_ZyQ0 zl}Cr(*w-y^v0Lq2C*oh9F&d?|#$v|SsB0xf?eK}=TgJ>FW0u|XNIn*k&IilBiEU1N zwU!J%pNT`q}G`&%2crPF1aY?!3-SXIz+PS?*koPExf=Bvo{ zv1N)EpJ)%|UN@*daa^QIB=QbT!8(QXcd8(*?mYG)`>mru z;2Eb<#}yyug`52~eb0VnyrEamw*P-F06SYcL&hhOuD*MZDL+ts+I;+q@}UI{WnF{7 z-Al6b{6C$4CLF)K{qUJrA>!J;t2rH|k}({HqHD)P2VBL9J667Rj(?LG)4J%PtG*n` zpE<6Sv%9IvF5Xu@=JvB+n#S?zVJx>Euqba=aZc0s)pX5Bj!8&=O+PX+@t()cWt(1q zYiCEq(!l~DzrDT<`r3(+gC)`uy}dhwUTrG)_EXA1f@`O2Ld}v-(lVnGty6#N6}9xQ zInR3a7mpjB3sO?hbX3XEsXLL!f76I(`IvO#)j(rzp`x3|zV|B%IQMgP_XGruFpuv} zclpyS_t)(Y-y*F~eT`PrEs1Nq`yLLm-Al~raFlrUOO4QUi%kjTXNVY*+;@yWyyFGL z!I&Z2faXhjAmS$hbpz98Q0W-)PN#=y-b4Hd|^2bZWYP+Au`#T>Qw`_7Hb`21N^yryy$UvNlE zeTkj9b+k_U2q5|wr=kr+q-|38n38M_#X0&OpM!&A_R_OVtKkDL&H+2>koK_Hnb3ra z3W+|!YMH7tpXKC!q7Q7Bj?xxVL*x0&9FihVT7dmw8boqM1U7uM%f(>r za^A0Y`z!JNH@p~|nxf_BhybkrdHNZ{b*i1IE>xr-1iJ$wqGCLKbuMA<7s<^dXgg6F z3&zm}R9(@2k(Doo@_BzC2mmf<^M9vJ`;R_8&GfRcVhy|L%?Pd^TFtuJdGRW~+ftm98;noYT(-glSRX@&8++KNqu zwTstVo1VE0>tnh)O6gjLiK+(&{(T-da9bRB2{S;1GpN!12fco^NVpNN*YfXsWY{Ie z{^`&ETm~aEon(5BCh<9^$s9?*AZAB`DxE5r@=EC!+hja#_!8@r@OL`^3k%)UMTFqh z`56`Y8A^Um#*JnhEjlfHe1F~g(Q~xrRn(#EpwCTPcV50YbztUG)0V+^*D~GCt?YNH zk9%^T$z(Lh^_yF%HYc#VnlQKs()$OIdNqKdbH*G?(ewv4ccV}4U4-c0<-23~E8%@t2H|CiY|M4ye z>!R3nQZ_v#q4FnI7m)p_lRsUSn6=8}?%HaVJfrFT`>E^#45oS&rG8&Z)yVRA@+Iz* zdX#-G-@mDB=GF3JN+LfOL(Hmp4pP3V`YdPP-qND)eforn3GcD~wo2=x#>Qv^b8|=c z(z4WKvd3k~KCMro{Ugj{Y)ltvM1RD%_?XQ2y>*i+qF;kTc~(MBd>JS(p^N^cqWfVF zyZo=Bts=?%G)s4k~J`QEmv zdkDpY1xwZ;dzx77@adDAKbr=RF17s#ndTSV{Y&-%&5pz!R^Oup*H^X&i6>;xnWg_! z)xQ2Zi(=J~!7L%LvQr~$7tN2$p|RCYH=iySN!xV=K2W*x=JglBt-o9Mp`KaOU^fY^ zxwyzz=pQ1a*cwx$ulghYqq~nnz%9+=bbE!D(k1iOVkFh1_Aeh1)4g8(WshO7+NDb7 zQlm5QN?vje>D{G}#l2CnGiV`Mq*#h7ICrt}U0dTt>3^d=k5$9yYVB!$8Ru{IPC-8w zJ1eWvM2tu-GbWu`Bv`>Y(A~2<5C!6a!=4bzA$&@>6%=Gk!Xg4 z(WTm)xrTkU?ko2cZf=;QE~y&tpm}5>Al)3p#qmw{_Sm-flq1j@@J(;udq`6xaL&?Z z+@C)p{yt#B7`N*InHceL0nD=DG$01x8Dk?L76~acp>GQq%b$E0KP0P1l9Yi=_rx zR9{WmJ=*4P3sYaewC0n^@bug|VZprzENoew^RGO|UV2(JjKN0VYVzqmW>c07 zOV0iK?`Ig#jc=l_57=;kdiBt{^qFYE&1+UxWplqnn$*u^F+X}KeH7JICd=LI_L&c= zv@DdBtYW$bXLv;_zMd$`GG;KSIal27Z7N)pq4cJmjp5w&z)z0WW!( z=zy@J8QSip3$xIr5djpkHi;7=QNV+2kK>T2?;sW6lg`WBWacnC)(&^oyCAgVYX5b02(7C@>0Y z!0v2+LiN*@<`Q`H+LE8&zuyoR>ww^)Y`6XU;}d*8%732ekwF;)JqwOv@R4`iKJ50z zF&7{@fN)}Vv#WLE>E%+v?%v^@rF~K=TUf87c-M zMFOt`suxUusOyhG;{tZ*Ah_uF^YbAWs)lOsWVzSw{QQ^o)VVx_QV+XMHp0h8Pe%QvWqQ8+g+6d)PYr>b)O~{=FJSJb z^-#r2<{tn#MX-)RFu*4;g&QMG{+v(!th~_K*LOsSZE#ci0u3k$naX<+K6`o7Rccpbzl}BQA#JnJPZhp@;+X@G!Py5=eNLscegzX z2D%zV;1Z7{PDV0%?GISG(D{7q9kl=rYy@|5L?BPjM)DQxv#`40p!^;y3RNLUmBcFj z#I=cU0R$~9HYNV&#M+of}`WsGjfQ=rT4HfiZ6M&Zk^*B%r7+a@6>l5G^Z)9X>s0#HR zSPRP|d#(^mTw*9-Dt-u7NrXs@IW@Hnj(3etPl(KPL4Z?Ak0CsM+!B~@F%l5=?_0x` zqlo`Lv$qimTR>7*-l9x_{sB)N>YV7piK(wlbr}T(5oqZ_Y@)wWhcXuypRYlLgD)hv zhav_nSSS{IK9QhiB0PYw!0`0x7O)mTgZi1gb9c)7p_#>7`_FTLE{u+ZEDd`6gC_f5|H)tf2a zh>ecsmYi#rAN`xAqoqJ@iNArp^3qB8RP|N#Lf3(U{bs?#(;yvY@Sgl%o6-(2@jxOq zSL!K)B`Y7oUhQFEK>{CvK!f%O3+G+vaW0&1E&(884;@OOe}WN_D);+N>?R;C6wTVw zTyOR=lC*BzG-gL_KIfc7@Jq0LVI(IAec-M8O>jc7xITbRiHK1lm$rp^4y{+lgaU?r z--*T-`&iHvdulTe1TPs_2nHe=cydHO4)~h-Tfc#vAD2CEu8}Atoqa|Acx$m+k?jnG zM^!X~Y2N+NcrZbB`(owWWF9u~=K&MFh;Aa+fSe6mnDq!H=y+I+K6=(agzpyM-q_LA z6`G$9onpu(Mg3w^8_=j>;L>B$`@%qWS z=$tfedDbuuJjoCL) zS)!YK4DSyZDXAPe^8Ttk2CCSyr+nI`=TT(y-n~3`lK7}scfax9MNAQ50)GZkQsA_jE!b(R?ZyNKhj4&JNz&;D(SdhX6!}@ z`xt(&JA9>PXUe|Vd}bH;SP;h@9bB&4)OKdFc-*H{W&R>V@cpH*CUwTh)3y(~B}8B4 zb}T;O&N?h0?Y&^l9nPPzA@ZQiV(E^LNq^Y+RoXZIDPUq7UB^=S;km2x@h?9@*!$R^|>m~LYF)jDPopj->SpL@8g3TvIP`k+q<)*ZIK%lPiC^O9uhKEtP)1%hXeNmzH;Oenru{o&cia+<=v4@s%KdYzJ3db3yKtsF zW&8$p5mrM6YX^1IZ78cuzs+$IM(Vtl=!ro|} zVoA+ly)~IVdOd4fA=9>xp=u-7S)ax+ypUzoR(aligUeXgTG%$>viZ&x%XJTV%Y||e z1Bn%M3s0{-v?;qs(Ek;)jbukR7xa%Y9#0cwoxm zaq)-z`WyJVvr5(%tKvr3jXj6OHqK-saz%~aa4=F;=l2?j7Jm}bmtP+o!=CUT%c%A4 zd!p7>s;~Io7)K@T;3 z=FEfYU?;XO*NJbJdi~d=1XXt=WR{lR#O~w@6PS+f?y#@Zxgd%h4fvCA z%KiO*Qa&S$yzu1@F;wj8;vCtzxi^SP8^3`VwDEaP`f90q;GUH@vO9DpLlIks~}2?h_CfBl!_M-gjO z60qVMFq_TyV)#U-C%#*Gsk@aFKWIr%^}Z}GCv^v8e4|(ae=}q+tZ|`S2C}NWGJcdK z0HQ2FJBFz&>1xPi%Jij&qjKGG7^K2DXlR4wLSTovk`~lPQb#6*I5dz1kL95)RW28) z56CiuBxFNwJG8SFMTwv)kn5smR;-3;YVRJU%EAfco;-nR5g^5&nM{TlKN%sU(p)4L$B0Mb_>icYSqL&)BzHV4x56FfiNZC9#h>=qG~)w(x^1q=LkKad}|9GNAAB^HJ{2^H3iv z5^G%iQtZV!iNwzhJhDIOYm4cmaUwS-`G9lzPnr0pobpy*j6qL=(?Rw*aP(Z++u;m5 zjNR-zEXGIi_E7<&z49;j$`I^5-43>KXMcNE`FA&l(Iky@;mw|aYmhC7SE3>#+!@ri zC|u*3N66zqE}SIBFrKz=&^Nr8ENLd$u-8 zLRga@j+W(iIHJC7VUq0~5Hp030<}K56~rZ#Jlq@m1LY=K_IIADgVh%u4HX&ll5Jq> zZa%i#iK42zaL^&HX>4H{%GB@wi!939jE)j7Goiv%{<&F{;dp1;haW_t4+Rw23S5Wl zG1CaKJt=PSGLbkc68<82rO*it%rB9g6jaC<%g^}#IF|myJ8zIoxX{U9s6hBFu1&(p z&t>~~My4UpUijBprSy>ZMoZ-)c?`e;yxE6{dk&VeKMYdrfi7uHrX5I?<28BK>$br@ z)s+@*4*bC4?JL=LnYchy{gU=ImM>KF?@m( z<%nH-c&4xJxje2%K3x{J7U&PSBw)_^9wuuyzD^S)gSc*hDuo2H#$4qU$GcZneTX?w zr(bK?1=QIjQbP-G9I=rlsJ_M49(Wvs;P-B)@}dsN2z?5kz>AP&=Y}nEc1Ez%GYu~u zIg6EkhaG)>-VFObMgjsa3*CW1wF7+HYM(mDRh!z(1s;<;&?sv}ZA{%M0}@C4zwpQJ zs#b(6(Y0(Hxu4=}1zI9L;)jC;5RW}U!rIjH;f*+imVDGX7}g=Jy z31&P#ID`q=k}T&34<5L*#*-}o_s6|62EEK5*pkJbFMJUKquz2K73iPeVR7gq!~!OX zQg|q%AVUqGsQPEDd=y@bbH6L5FUR0}|2so*ED|poNRH5qflNuxbihbIL-yl^kVWe| zivE2of!w$^mfNqQ)k}jmQ@mMCm-N$E_uT)zmbv<#qf^BCk5I&zd*{k749_e+Ul@u2 z!OER@2Li|-XLv|RNQ~mj1gpUDy?)|fc$+ECI6eC(8Y?tVWT?w9%S->Je_P1vho4rC zS1bEoPZh_pX)stts3lskxK*5b8XBmb=$t) zRxIf6ZLArVR_nM-ry2Y8;W1N-xAX!Y*N#*E9aFzXziu*1Xz5OLyKUT0W0sK8x>p9* zmuV9~BJ94Y8M(qSc&GaDS7{D0&Yv7VkJ=g3QeD<7`q23zs{f=>sGzW<lU;ZVr*g$i8?+@@i3=4 zI8byfE=l9Nulnk(d&?ZU4{i8GUm5E5OFS}qp1+`}YL&h*QLXD1mF4a5vGW;m{C@n3 z-7$1Q7qk;4>5lx40~4t=cDq={<2M>#&&MsgQ@Y1ozqp)vu_%<{n=I#8imtOwl5?`! z^f%#eDm?dw(lZ2(gng3GJNx|K=`2}LfLOSusgGEh4-_1ui=K(I^obQ{ptZaiZm{n9 zZrTz(LG^oHbo##x15_Vuo;#xJKfaxd6erU?mtAD>#)DFgbNdjEID*c|D3={K|1#nWOm21;JdNdbR}0 zuVe9p4L^!S3#IgUPs{)6JfCR6Ehwi$ZJtHL{!w?`9FJzIm*Ee-0~fX*U$uw@8m_}!uAz(i%47A~&+Fe{47h(+MdUz&K6raIBK9IJa5Bnv#LhNT1Fb6txQ zt2AS;=n#w2I~i|X87H7HjXs9xf6{=^Zv<%J%Z`MJ_f`l5h|VjB8b1W*d2=9QgS z{Th@}IH8ekNUYRgWR$Suh#PZl$XOKLhoFFl-$wOb#?{|^YjM~TgINe{1@AGiA&&~N zxmwI3*o(N z{(fdAWrs8ofP`#w9sn8&{3>2Qw`s%;nvNNq;O$Ju+{NWm_G@)#Io>HBoERT>?Qwv! z@Yu_ixp_ zfdfJ#c?Ks~7dmTy)A~XsOL$|3o3A650)P}CIPz3*&KIKzA$FQ*&v5t?Iy)>UCve&9 zA|slH#y4>UboBNrIyL=B>-xNgU{&PL@Sr)sAK0vU78vZ|yNb6V9VI;gkvYa10KW$P zUMdv_(bQm3d3n#H4j7P;rVT}c+s#}04wM2k09$dqvDNK8m2E7M>|WO-SD)me@E)x( z-Hp0b6!3(RhHB$EZ6CA*QZT;8p^t7H3_`JoXfV_F!gd-8Rru?ilF-6+qbJY7b-TC7 zHyD3@kpCkj`q$D5@ZmUZA^b*XXZU#q6;2ciHpewbzq*B-P3Xh0z;V=vNE?Ci^1U^^ z6>D+G_XgG`fPIpW0S7tY-bvyoCC5~c(Ot{Wcj|k26XaYuZaa)5NVa+*>(Icyv0hq_ z?;qhcM%q9Cft)7PutkrWupR1s%8Q{E5NzCV(gG2Y1W@jF&oubSlDdYtX?L1emt7z> zZHV5$;r+v{O}HU~K-u^A34Vw)09c;@giqZox9&{= zfdb&a9qJeeBgKO82f{Q0(e4~GG#uo5h2(zxg*`Ie!q(ZAk1k>_VU{BjLIDU1*@Rza^H- zgPTAs2uIVeD4`a;wdAr|f9k96t1p0zVCogqta0F&lw;Z|I*@pnS_rYRj%MQ$j7&|nkwcGj zip3*uO;oDpH9WqqVuH0Ly%p`j7I!HaZD6>&iIJA%?K*T8@(j9|_EI{@Y=dJJ{KAhK znAzIdeJHdz@X}*k(1ZK^WHlg}Tjbz`BTDVyUcB* zsyL0a5Gqahez^brbOL}QNnbd4&>|7ZCd7dg@6lG}0L6_rf`B>_*ap7fX$qc#@YqRi zeuIXcUnN0m95lJt=z-s&--ko*Mj#qEK8M|2kjN*DuyE@OpARwoM$;tJA|8ZnVG!Jo z9Ybr&dbuZ4eLY@lP>b0pRlYC(xul4zj7v>@N5o^;k6u@NV2pj*?SzvpRevY2z7aq- z19i;Gn?1b0-}c0tY-4LnZUTY>V{LC^Q^Mq^9zTpoMS*}SOPk-bQ++A8QE@Pip)j%N z9G<%aGh9}h6L?}62fnuGw|kQ28uHLb0ueyYtYcw8C_n*Bc0aP!fAHBzxIWMz#`-5O{0R${o6J^{5D}Bbb3rfU*;+! zg7dncA9r?tq8G>8+A- ze=eln?ef{+z2>;9(>(IZ(@_0O2~_7Io_v_hjg6jBJka0n<-sm!%TQb)7O$>$Lg2A{ zNV#@RceC!l3%qwC9*QuDeLlYRtIHOX^urE%k!?RjQbp6uHXJc>F+OY~K_4qf*RQkj z)QXb>C0(+ZLuq}jxS#O1qz7{jd4Yx3gm=XK`ReR%ofgu6)mhPUt}%q)D?)-Qe$c>p zn{=C5HRC$zDlhIgiQnQJRt0FzR`;uI^ENa7WN5AMawA7f@WF=0GsX8!wK>PC+dl;w z<*}ObY5r}9u+ftoqvR<)S$F>#-BnGGVZVl3G>uBzm-s(tWOQCj-8INF!9QXoTeflc zZR(fpcDO2rAUBQeUIx?!l)87~=1OSkx5yQ5F*VpGU0ZH77QQ+mC9H2J)>X1wh|xWe zM@_kb?qQ6ePgKaw1mPcBXc#HocWUSQXcgz^YID%<&d+o*uGum8!+!gSn{3NzlkEj+ z>$@e+{^r`sUFG{x_bS!3M~0KvC7e!5o@CxJuqL4UT>8*Mwe7yzs?6hK_a)YGp0~-? z_;Sx#JWsOyvKZ5hI?XeUVp}uqmu+Ewu{^CWHmdcoFY9gBb`GmPYFOK2-57uP50eBN zZC9~(d^H)FF7sn^gy_GEVY*GuFo0~&pjGN*;@R!p?@kn zD*Vfs=1UOoK98jMLoZ_wvV7spp33Gg-YyosQN~EZ{p|GF8rd!#2kXdR&g+`iMSTr% zIL%&v$420%ootK$u9kk&8=WIC?8CT$1o;`_;;_e8VH%_T>LX0l{5W1NDgy zwjad8UdxzTfVm)r#;oyqE%WnIH8C(#ONR*`!dIKgj*p^hkC7WZdm3SuhC>O=5Rz5o zT(#l_3tx?IYsGd7Hy@eof&&!sPxOaFH!7ulIGJ2^kpEiXGg;iAL6bSYS4UTujWr4o zyT`=mCK{=9D;UXyMTv!9T8{198nuCvCn^&^>9ks(2=Q4 z_^v`C_(Vg1VrwgA(Wi@&38;uImUw8w;P3@7Sy2fIrCpcrO}xCNjG!264C_f+3OFF> z1egiL<2yIm1i1Y_C=iqtxM7oIX84SPN1aNPO1x zu9*f2bZjru!OyA+Wd z=E+zGqXJEjfv_(E91-o0%w=IN-Ux`V0Z3I$^T-ewOm^63$v%qe{&($$J#Cej{V=Y^ zk&pv+0ZQWEh^Zs64n8btgh)el`QL0Nba}+#5aAmHjRvQ}VDokS3PL%-W=;S%K$je> z251&xt{LI*kucX`cyIlro8X~@c6MVU&k%uIaaB8Ug#O>O;n_I7O=5V91Lcsh zvC`tVO@tqWmW-TO0FlY}Z*zWrjpu+@DB}&ZpBsJ+lR$9j_-c!hLooYfhlwTZq)~dP z z)r-UoA2%i$Mkqy$#)1%^!wAc<$7kytoSnZTJe-JVh5~!r+EifeOCYrqr(DZTX3Ix3b3=%XHP65@sp)F(iZAD5>xFnUKiR=HkLuC*LJN&f3wG#Z%_02`;v&~7kF9}7a_i`J7CJ&xZ!h=02$kg}%&EGZ?iUh2uL5=#JVWF&I#z4Yshkj@aeWB)}UfkHR?vnTlmuWu+;IiYtD zXZJa|nXm9YgB>^#T|dUa<->;$&ruL@BB%+*=A@A%DECuX9qR`T9=Che)5@vMa9xdD zBbFD{FXC8=jx_@AbR>}kg8D;v^Ikw5<%7>TEIGefGZR@Put|vl26-KeU?Uzceq0U% zhw(3^yN)XUrioGfLq|a%O1vbSr5y-c5rjqmh4y!$i*PnWNsrAdWC|uK*RsujC z;TNp0TA4Ekl|$XpQIv$We>T4;EGwH#))uYad-6H$Bc{u#hXd$IaY%a;*-&u zp^NDiyvjUt&~EcAHc)+(Hb?>)#g=)%5*g0nSuUT9Pq?CWB0(USti1zdEE=f%?n!gr zyIy7WOcrcO)^ISBGsRFf~3h}tNky6ynIp;WpFL9b=bxw(aDh38Ox~arx&4&VL-KxDc#2oNjuh`@{M0% zTX*{MGz<%#?4(#=uQj(=yauP`9o|S%YQsKR0JzD22EcMIS0^W(NMKfeJ_)wMIDm{w zP`pokesth`%k;6^g%c=WrF{O_AY5a8HHIL!@Uem1z(MZD+UoQghf1{L_me-8%$)_~ z<#9=KzhbT|2Vv8PTU`^WgN2h%)#CLK-MyQ%q%2R*lAR2%-ka%pVx21q+>9hiN!c-y z6@q$tsD9f;!roa}S(RyK+>8AYb|5+{n8ibE8E2YWLo&PtFG$`U; zq@SjM@eY~JF2L>pPxF-LCu4%+U^gdn>fiXU9e-XfD(Kj4`LC@uukP5-kl6tYNM&%C zgE2|c!SQ~Q*e*1J?pUs5tlvQS9;oA;JG#d3`{i|}fh>R{SR;PJtB`}0ei%uC;3EeD9HIPmw?$6MGymN{5&EXs)KqQ|(5i9# zinYXI4z>lj2ZvwD$wTO?P)Uq`ty0F4BbrQ-c!Kt)uwu@P^t-OE`>;~1aUuz3jN{LM z)QON11?Lwd9Ik7-k@<&0#TK)EDiP2@a5fi%F@cCKZA5g8A&C;`L5{G!O94z-9KaM5*P|V|w8Km`N?GQ{PQ~ib;#{3cGUqH)z4zUzD z&%j#QKoCOgh>a!h(4~_V0?>)=_NVVBXFTSQ{^xcZR<1WU)*eK-8eh&A6`%W*zqXx% zjx+k3^^>=PhraoI`O9mQt}z*_I~>0Ej+LC5ZA|dDy;&*7RUx|@bcbxzL(NNm&~)xH zbz*Dj=csf28DD)-?#ytOyJy=jff4;JcPm3aA3Ad+vU=VyXXEY&<9Cg#F7u_=*r_Ji z1GiZ@2>jaIt;2b-Azq~Ym7Ca|_7nT`TcI?|4#Jiqx-?xOy zcN76mJP%D%KZyQqk9xN)%Hf{xW>bfi=X6~&4|DXrt(!|&3O+ilDuo+mUu$rw)soru z@d}IJF!SljG+Uzs<#8Tuwkn^(Zl0I+=h*S;m}<)c(|WG0L4moRi?st1#c^&ox23Xl z=MSvy3*?oOwzEmLk(;8tcew848DZCwh~PN=+1NuaX@PHc4U5DUUs5U23*mg?=*Y(O z^?Ix#TN1~G+xIVSAEUn@_eMC{NNVNIiaBkqLZS0UqoQADbf`pC4|#1f9D5&;O;Z`g z+3LJvwBbQ17gL1#f|b6_GmT5ntZWA4%r6_S?-t$ug0byT*g4a9;rJ_~JI*FxgnLADVwb=2jCFxYnr0k5h zTWJM#lpIcR*CyH4^!0@AKPI^7w$m&jsOxN$_HDIGxoR#JiCwJQhjmYQoa+18Rc;!_ zW8!GKIV(V)PM%fJ+&F8SvD@jgp~hzela6oJY-zYTUrOy)a5L44DYIcXDe!!UZr_rC zV2ah4lFGg>dTu%ne-$!$7^VadlpWqWPq*#;uDUb(pQ=vs*b0>gU-8`B*EdAN_~SW) zlj#%Ttb1am?twH(nz0E&g65H-HIY3PatbFOZ!x927^)f`%<(CTQTdHXfzJA_r_#E0 zRcac|n{#NamGqi+yPHT|b@aY+!1pkEQO$!7Sxu~P=c8LpQZ1YoszR%a2a;*pgTy}V zSq$So9{Xy*(urFpp~%Hrhx(%B&QiXms0wu|A^K92djYDU_q)$(ceHO4cq*t{-kI>& z=Ew5`UmX90v_{rXAFhuTio0{?EjDIt6ci*!gftW=8o*>CTQE@&U~>o0;r{*m--8v) z@N#}fq*X2eQZnU-o%`(bkJuNw;dzmtZv-3$EMjiu`8+2iNuX_D+M87icJqyfVOG|2 zq?$*y8j-z|uu))w!-)T@oA@hvv0ss#>!>2oNHJjFuHD!6?4l*Sk1;78KtVE$1OQ6L z$}b;27%8nSyQ2){R1SE>A}B5-H>Sl@vBEpj)q5`&jW-xq9!9XBKZM$qMw9{ zE0RYkh}}>1@m6Ce7L)yfQh0SLo>GRDk=`Q_xn5PGMkF3R=LXiT2O-y zeeb)tPfGVO;6B8hhv5*UCRw;7%|%xizMN*KxG6&zw+1$kSfS!O<@$zw=p?4>^@vd) zP4kwtX=Gs?xnTDl?N~=!8wF=_JEl`O*50TyAKZN6K5BSHUS^yaeX zsNg^pVdR#Jg+K($sRj%Hbu+U&&mX1)&XF1hT%Jv>E2u{BDTeY33awBklCb7Lu5lDf z+)DnYPj9@pvj(eb8Kz4^xUVt0EYIh!x9CK`9qSoSein#g(>&7lQ$B=W8sXam<|7_g zXdPi-vZu?btbbjDiqR zJLrN4kp+lvz!%sx9i$-tMCt23kaSarEHd#0<_}{0p3L_MyWIl4BGHtgARX^s3BP9?(Itwtetj znrZ0s#=4*Egi|xTt;471fuy7RBg!V?aRuSk>(%l3C7+R-CxWNyFYY^LZi`6)Scgih z3)Vn0s5#xCWLPDA8Qvd2#F}%aC=h;KQY2PZfCf1C28tbip9W6Vp6D!b-~#vbwt!W{ z{m^6=PWlm)kXzD8h#opG44TxCme$nN6qs;OUVbt0Z>&(*|H?=-O@cedg*YO}q^ffH zj~)fecL9~>c!XHWu*QH|xvk*D9{dj^t@i1cli<}AHrfge7;1DxgmTKc@S}Xna*;^B zdDC^_*V(~TZ#nK|pbd|$QdR$-3qZ(8$gxmUtc&a_Leqm{*`5HyHa3Jg^m|l!?Kbd{ zmtH^LUZteQXY-u+Oel$1$yhzJ6P+Bu*#{wH6g2c9L{jc*RoG z=RC#z5EK@^{{8#+;kV*CY@{zE{6P#ZNFPJQ!Z2QE_*sF`Ri@|(7zM!>jik&{SofVj zcPVgcaJB}6HJ=M=GEj%pAzE(o2__>lBkgmk&umqIxZ zy127hQa|^YjVaNWlQwbvLN|?a^)0YHz3ma?6I^$eL*i5;_>TNamai@qKbIoIc+#4K z7mK7km-4>c(xe-8XHhwNr5+j@za^wY!ZE;NFmvRR_kfB^3T{|Pt=bJ3 zCfVAYBqlC|t%`J*FooSGOEhdMF|)$mB*B(7XAIsYfCx*n1eI21JDyj3N6{-~@zR5J zdf(E0@9BOruib54PkN&%ChW(X63yVulnrbPjgMt7y}5mH>Vu1zutPzEV|CC)PGc|7 zmi3GH&_p`vd?Aq6>+aD!jrt#locrsh0>^@FC0oNh!$m~yd_)t!{z4dWgp+_v9; z<_OXSfCkP^b|*OZy@dM{+_C+8dE9XwFo6mMrb8U}pq9HX$NvO`SOXCnPWyNqERx_! zB<-=3xO=wL49@!Oz5%(CQpcRILA%=TtNPkX=5?TQT90w0{n1&Axx-8CS9tTA-YL0q z^R_!4T*<-ok3t26&&VyB8>)9ddGy?`=lAS0YK2u}!yo<3@or&T?D^=0nlg4HRY`@8 zZHIym_k_kA*B61m=`;cpW8Ojq!R19f=4I%=k(9|B;oWgKD*Rya&>=hQ%!u}%l#()&5i`O z>+4S(-Q}y1yI8*W+Q=jB=sT)WBRuD|<^nn36RF8GVS0)!$z!Wte6Zi+s79rAZ>~f- z{JxlXUZbJA`_S8a&XYcp=SuZEhUnsW801vKS$4jjoV?g@;B9olGp2(XA6_Y{b<~Q8 zCYQ&)e9Nz=C0b>6*iGeO<<`3KizM10@<`EjOLH>ZD4;A}8;)5gpr`_6scFeKj@ zot?t2lf>AcRrZ#XcHU=We0}+2+qW8X4$=|~*&oZYgbr-GA6QYrK5dXa5Z8AI{Zc`H zQFpEwh2BHkrmOq)cNpcj{qp3jIMvYoDd?NDq-Fs7txsPbW_)(|xVlT_ zQW4KlOZCQ{J0jN8RyCh>>OG^GPGm-XnWuVv?)WUfTu4uwG|z6TF}AyBy$ml_&_Cr6 zaEZ`5q#xVnF}x<6a7a&XX8)baXcqY#e$UwG^I}cw2lwwwoH2HK!NGmdY2k!MJJ){H z#-_(?pPHq$O>y4hzhzl_*HQI^>V^r^S0$lLr_xL}^}_ExKh zosDemml=+H4Gz0*{LSmo_>hOE^kch|E}I)21k!$;URQcttH3ZU$+1ge>=MJzkFTT0 zw})~3(7+V#P208jtB2`+h7T5eek^4E_}P<;_3{j6=U5IrWbNN1H5=o^7Qd}_Jv+L9 zLXQn5>5<>|f3N8_YtXlKewIQko&ar-7ztWNP8IB+$Zg3$Mk-(@0y*JmAy5Rd?%(uJ z`i(D4GnL6y7+k2LSAQ5#b+@~Eh{g-iQ@lA>S|4V0_uV^n;*LoXY`uiA7;@|aGYj*u znNhJ*@`hyB3FK>l1UPU}0{NrwCe=GMeetQOk)0E7M%?HJWu>C7fYdP5W_kLw{$5TY zz>6(|XTz;80(8f0A#DDVA}=}cF&&wEvhgaaWpfROl;iSbxL#B-$0d_;L?dXYkPzLG zcv384FW;b*2gIKs2-E!)2;AJz&h5V-x-!-t1YQM{W6etH+jTM9nW z-IB=?@h9%1|M7$3_P2l8>iRNyyG}eoiJ`O39q7de!64v>+|!tQ-uUVk{}t>*iXgU} zK_O4lSd^CsnXd|j+keQzDLOdoe9$CKe#sNfDRB<46a1nxrDolgmk*Ew5?M4f+pDQ*HpQpl5sE~y4Q3Tc{v~uTId0nM!eXo;Lf+pr@vG@@ zrzf#;MNQyxi0jCXHZRbc&phiH2M7g=vSOE~Dd&^Ko54Y->2@wd=`9aJJbUza{JP4u zRX@r^ZaqG_`nzQi(d?`5b(RJB-Wqq)ejNo!B5*9cO`4v9Pr4{Qw^(?vqB?1txmBHw zxjnj^R3-{sEE9i=G^(&S5nJ3*2-l~rM=ytp8L5MCcSZS0v^!SsZhgFT#$W*W zNXt0Z^K4plUip~mA|f(bk2`_Ag-T;DW^o=COe6kF+|(&()MIErK+cFxhKyUsuK~WQMXTfL>>Qtvz)!`9bB5QA-@v`uUXqsM z;?1sE=5dCOjqj8H{tZ=m8(VZLsN__t7A8<+(J2LRTZ*#>{rl~TPArc4ap21029Ebh zU>&qxh=80xIWh((6vhSe?cPn4G3c}`%1WM3Brbt#hAWtbNva&j1t^NzJ6*Vac{1?2-#ER>0;|x!9faQLVj<_ z-`_oZrUJ(ihNp|nt0)G1e1--P@y=luThQaFqYH&0N_$WSB=nUY_Qz1AJ>{${D$@P? z`jvWjkYep66w$G5=;%5As5Vj8Xg@jbC- zEgj`NF(RN47;)aUkRedu{Kl8MnTt;n_*fJRjei>7bW0J-ZeQV$SFrY7xzE=mM0Hq$ zt~~bcZ=|JvXOFuqGjpM_IDA{F9jA5W%21Cnr<3OC9a+E6C3l8@uCX_&a``YW5WWP#)o?zU$sh!glY^jg7vW#;Jn3X#m zse1d#RluP>Jc5s*dSk|)8(~*le;m*<`mq1zwj;TJ`@@&HcW7jtGLD?nWb$#071;UR zZOeHf^>4!xrIySQS^p9xuWwOjq-eM@r?c&Y*45q$=c=%QEannvDZ8FV+De85b^Zqe zTHO;|3ARtB6Mf%V>vDX2JT^$h{@+yk>s>LMxC9?l^|uz!uhyzimFa4m3-!n*9 zqc8CB@sDjZ3fHLJbN(9FTIm%(_3)|WfIfAFfE%}O9kG>fw7C8;af7Pb4o7Y2TRq1v zGE(&VDES0k01{gJwaOAU8d0XiT1seAMBbI~`H4C-5i zi+2QyH#>hCr}xTe?^UN6(YRk%ck;CB(>MKL1ux@dZJ$WLG%U5r`>{kp^ElPL)MLaX z3cq@zkabCIMJ|_#3;Tbb^9wZN77>DxrbQYfdqjF}R;c8f@Ku_(j`V~H^oEs%>-*)z zRmHHy7zx+&=~bKTm&&2o`TE(VM8A(a_vUtgX$vo9wvZMSq${=CGC(IaBFOoq$+3;0 z$SXWbfsy(g&yGrMEptE5F?Qy1Uw%s=r*hS1*M0ZxHMRO{eob*v*Qg$-yz*k3c-F`V z(V(qO5{|wd+k9-L=im7S{|($5y8nkv%w4_}o;Me|slTD`#0110xMspm)T8+#2_>km zBpTi^V~1Zrp6)ciy^t<*++j)XIH=MKwtA;~a~kl~k0e2ihWpgb(^GKIo)7ZMM>UNB z7t+S-XmH2AklKmi>L^I2{S=;t_QgKx?Cm9dU3ARRL4Vt!iPB~59D;~-)}vn zk{VZ9MHHA8AXddF#rg#2ZocT@7bEwL1`cN_*1 zYvy{C;!{#t3Zml_GTx)bj9QoLI&NvXTf(<&0=UB!H0VTf=}Z24%4RN!5z2svFwxK% zXEoX1K>sHFR0sWm8i^6v2*eHO8Pgai{1r076xTep`7`5oN2}q;E1b+y4cBnWKQ)(( zuJehngUdyVX*K86ZkHKAtG>o>o1dOJdzL_~By;c}hJ_Q!D()4cNCnhOfJ;zXdc2H` zS!#NE(LO<2x#hpqD;&%)RE;4Ly!7Rcdqz_}$zUDtBPKy^A^Wk#S)595g9}R1*h5UCY1TpI_Y;_`c@BYcM4+_e6wl%$$$R5K@F*(-+F{!3jSCK?u ze&W#t!F0p-bfYJs;e7GJ2~rONaY0W(1c;>7g;tsX-&g^&cuxuo$;1A(6qj7IUUzaE z<}YIay}rC$uiSw~L0!m(z>;y+)dDO4s;P;j4=4jM_)Uh3ty-H%TyQ*-VS{&iMdrM% zFa3r{60q)lzlCFXM8tnIKadymEbA4_qzHDfN0hZ!w8ZGh$UtA8Px{&479Gj2y@4BP z#?B;ya)x#@Q_j}Ry7@2`7Y^Os!0p*Rk~?1qOTDJhO+d&4)6j*o4?|`R-4`Yx9l3n@ zq0g}thA*|EfrbTtz1GUmc+jy36A?R*F0nc@J~e!-WIr`%j;b+ zgswsk^A2N52W#J)0B;Gz42Xo%skDEJ-G0nzs|9otDYG?Wp1$!tW!*HM&@vFT*Y}Sq zoS;BF?*^Y7NE%-XfVTimo9|?uRT*X!d;=C89TM+2VUXkuJe+^dTa@)A%1f`z_ub2? zP`*IGq4Vm3WrAS8Q5ofvT2C}8z@UMfcD0DT z@cJa1?{X8}5sz{@q^?g_!n<8q&Q5T@M?Fi zAp{W*s}B0JAVLOio0zhNq*?B}*2)+1--u&xY0hhJ!Wuyg!E-J^hMN9;cbZ7ZH^=VW zF8mT%LJdKDkb_~7=r#9)Ie;C4*ZbIZRYwE8*$*8_^T`B8u2w9Mga>-9V4#FO1ac#003|8o6I$c9w* ztA#;)A-B5vBdt1pY8Sgzil@FDX8^s*mH+i2W8jhdQ9^@^%I4Hufmyd_tTsJ)yK{=y zxp<3qsQ$++y_FJ5s*)U?GS$CYJB@;O#7VG!&x+f1HKHp+$Bf#dVwP*MaYQDit<&)2 zidg(>vEkEy;%y5SIEQ(y%6HBFUO$lPbj9uV^QrA=EkRuUADHKvwrYCaUb`EnnWU4s zAncn!Ew1p7C)=c%TP4xoc$?D3AD`H59+pBQS@Zcy;lFHinzCOlR=kxL_EqO zMQd&Ss@q?|*jpRS>LP^3D0GKhy8EmS{z%W?F(#?u=U3ntuxhOO;JCVoO)&?9vu4NP z6a0;)e9J8~AD6`Idn%aTxVSBFo3|%_-C=Vqjp5?X*A%Lig5n36t#S-!-ZOEzY%&(w zc~Di&@Mg zHPCx#-~@H&gNtLTMg5VcT3>>l($3Md<$J`#?YD(D z;xMDZ_sd8eq9mBVgcB4eWpCfzL3}>9Z0rR_nGH-VETnCkfi{CMZ3wQNBM4i(yYe37BEVwc%{QG7_BfpTsQ7Q^-+?)4gI!C5 zzk@vfRMfu6jdfeI9h5O(F>DzL@U%U0+B6_w>CI>gUf zU!GbCBKLC$Tn2@C#}zSmt{U+U(F1S5#l3>u3R^S4qUp0QO)U<$0)Z@? zVB#UxakV>z{tuBbRtR`GjR(O1=%E0BXJ=ETb&t&bmDwO@q_uq(N*zc%f5Tm$Eae=X zRm%nD2e1Ud@oPXRD!b3kH%e$kTi5-DiVVYFwu6HN;ho`QhX*^DdO-M@TNooK3mm<7 z`i$p&>S{+qUV%~z?`MqI)VejN0DKLQN1FvvbQb7 zN9HtFEu?eA1mbD&gVx!_>T7>&O^kk8J)-i>@rF`64DsQ;?@1X1a+J_Nl?D3RH&5jY zOha$wjKnsZOJAR2wrVl}3|w_r+6TY_TzDJ0DemmaOr4Hti~pZRvUDo6Am4FTKb#Q& z6KqM-M5BkVig*ajWNb&N8R6a1S;@twdC*JS+?qGYLovgD*5qROZpn){y_a7)-T>3R zZ=?<~BRJ<5zyiXI?OpXyP*)*m-dEF@lnR_K`B+NC7lBIfe{E}D5m28lCuE3cj}aa> zI^bO+T1Ya$jJU9fPY+ZeA>D8(1EASY%Bk1`N552x$tS|Uljq)f)UANbia~7VR z@8jbex(szBt4W?Bko4&pXrS{!;-J$8#R5EM8}=Yo_2KvwF1ltH<>b|# z3V%gQv-$GEDfh7>fTdu-NH^G;U*+q6_K~r=C^Dd%03}4`GLc|PNLYn>nC4y` zIF;A!=|A|QDlcX1a`r)pY9}u(Zv}8AB7PcW0H9}-B+hD|8$ z98kje;XB7&J(@tzRfi9Z0hkl7JOVzwx^PV4EvP+6BI6X#0cC{(@d9FAE5z5AQVP(dkoX3;_9H~)4BUxos&j~@u9IZz$KQx?Lkn^ci+dGnQ4 zTiwq~J2LOeA@tH_Z+!a&sZR5z$lD%6L5#fh%A0Ek57p~!l>Ex$o35MH;rHbB^97IX z$=?s!%X|4p<#tM364<)Q?`?r-&FM|X*&Rm@pRzwv#lb(rMH3vARj{aKq7Y)MR<{r# z*P&D27bm3qBtf{vtVcdZ&HV9VqVB_i@TTUid$VWOmDScd_C9GV8#UN7AJiAto*ZrJ zKJT2gy&ew%ml{s`K+^-elEe~~y_gOhmt;pY8C_3z(wp1eLwy`Wp zn7|$R2l>8hKR&AjNRYRGJD8{ zP!-X6F2$~>lnfc~sFA){zA=d=CD+5o&0*(yDs9DOv4oNhGy{0xS4(b$JlQ9gPwBT(#!dH8u^^IlwSm)K!=PxUr zPRepkl~v?aqte~D;Xj6oV^;)D|LW(m4%G%6Gc(k{+RfvTy6hxt>xGYpFTP~ zP4cGP%D>6Pq?zm5^D!wW^(LoCtP(-DrJwgtIz>fT?fqG7ysJTTS)GpO?C$@B1opJi z1uAKW+WCohvfmV{PTtt8q%v0x*(J`G$FPCvBT+*zFDl~yv;ZeAecb?vk3Zosd(2l% zD)YsWNu~?)S6lITKh83`uvk=351VZYd>&|9r|?H+A)l4sGeq5Z50_=haa>DYaVdbh={ChMxTasOBUpz0S z*e~sGyFLCCnNUK+!cdZ&M4BP7pcv<98Qa*rz-2LL#%lSkUi~i^VVB@7BkQ=&@F~krC)y^^8IK!*?`& z`*HltEFv6s@3C)2CT6ClMk$^AT~ev}6^OkB;1PXF=Y#>Gj6-FZ(^TJl-ywq_EBXL>ouoXh_BIjUZ>ftmgBS z<`a;D4x-R7G>q%}&4~m>atJ=V$w%a=&>U5D*I=~qV8dBgyb|Q?avwSmiO-oQzTN|m z4N{~w-0AHg79LH_rX%ngoBI#+ggpk`DE|0v-@ z$X|rpBX1s(v%8vXo3BL7EMaLWp^A*nS(UExl9rY>YD7Y2fCr>5kQx(JH0dfK0KB{X zDDr~09sbKJ6|FCFMvz%GL6Cf0EZ>Pg~eiSzC9yj zd$}3vbnk4Uy}M95q9h&m`UN+>obTijJ-sJO0T}JjHa1G*v6p=d+#Mam{oZxQj~?Zf zs=1MOh)Sb<67}o<^hVpCcABVgmAY%B&+ZU3+?jAyYbVJ!jp# zd%+JT_y~k^$;o{~0lEww;Zw@j`H_Cm!r&UOw~@f7U*%y^CBh;c6j@M&AVlB`&W-&Z zBjT{n9)v*$6Sv;dg-)LEn7g!m{tzfF>|KdH;J&>stl*PZ01d!{v;LjKsI% zdyX10X84_z;;`)qeP2MWZsf7|K+&p`C2UamK)+Gh6SXF4a#GHNa=a78A~JxMf8(}Y zFCZ)Y4Y^y}Zz1-WDJaB{EJ}FmnUN!BKLY83qJG?F@P+^A-_I;5!|oXLm;g@kC`@|r zy>TG*Pg7Hq!grjzY_j|m7bq)aW+)+fO~Bjez|T}*4J=eukLilSgcb^2o|Uty9%g#F zHy!1Kbh<&K{-t*wtbW!pEc10?j~bciG9<@Tr`<#59<3gHvr-4*cCim>sfRTE zE}?beFm!!eXkTTmT{IB+>FCxDvl7}9w$#>1PLcH5vCmv;w6wQ;u>UXi%(|6_?}Gug z-l4|s754G%x00liorkQ~B-qz?_{GvwEg#t1ta!YPY3ZinH^1$Vj19#(nd7^onRcGY z<5K%kwyI0D_hdB-v%FMU1Q%O4Ls8zVKM3KwZy#+q$o_Qb)OLp4NR4l@5{JS+KC_FO z(=>LUq|c_2$vLba?0#fN8$)qhF&p>q0C;tsQo=F|1&WwW!Y>w6kL|uY>tOy zM`ric&R2yMe|VM*R|X^{ZE7#wIIG`B%<)+g$dV=aRY1a@WC|)^z`z zu2-CLwNpi#6c?&QQ!U2$(-%p%d5dE>#@yAQ4Tew`hn*lH*# zOw%{o{;E%aJI|pq)V)z8Z&reDe_7|(^#q-HZ&Y;Wmrz1~!?)w(j)DcmhZFIUE{O8bmT zsUdYtgpj3Lz!8lfq8t0}t_6S0IG1iR{Ca)z+@BXcKi)W~TI}8;PFoprUuUaEWZSn^ z<|f+0C@ulH`y2ZUjag*=ykul$)1ZMFhF9EKHPCvY$@6s8;QKeLt23V-zL0x-MG!8l z_bXEE8H(5=?=SW!y|-58;cVGw-&f)*<*DTO4khPgS`*wck*(y+$wHm%-8-NNvlO!?5rQY z!ZFSEkS91Q_htR&o%1sr^mPOOUek||%x0u@T(d^~UTgKT^abP}+H;v(=k_HPd@6~wMY--iY?J3CuOECxbh z#33UmWMp0eEh(dBJetkyoE&H*n#}U;P)ymLd#$OaMuqbQH26l{glcvBfB-q1+hlGu zDEp&UR?pFVuY0{3%#shsJ@UUSD#i1=6a6VQ^VxHvLhvDB(28lhsu+>1BVfyK6$NJFSkAwrnw*-(f_!|cdjuxyeFZQma_D^04orZJiSbxw!Zxik4$gpt;5az6(>w^E=<+Nk^72?K+zE>)b@T2(zg#o^y+b(5(KSh7On?4ytY zQ6>0|qNsu5W@B4h1BmP<=6?~hQ=fQYYC zeSn~JBwZAJ=J0r)f(VMA2AIMkV`JaJzzh~wEyv#D(@_THOLezZ@4)(mx}^(lDWaw% zL?EFdZb3m}T3Q!A z+?$i$ItdyY8ZmosRqQKCg_0n~;pRBB`g6Df%BLoA1EDi+l=+)ES66v8JAN zb~=DkNuAklk0uxbAKdJ+uLb-+jf{>~A!Qq`G95f>G_?>v9?w~z%k)4{CX~KJtcClK zjOf9qA-FvJ**5@W_`*F6xwTrI_iM}pJzX-c`aX=8)>Y~LovrS-b$J=x{I462TFCnF zFfrYA_WcK;Hn_EISg5{q46HVyCL*WJ%-Eq_9gMip_567~-jDUXUQxC9t{a=uy0uU# zAs53S{a7;w2wJ@ok&-%=^Ns|&VjS^g#l+IMw6l|wOvS03evGEswxF0el7>H2tFd3kwUHCBl5gS8!4+BH;m zFl&i!cT$#3O5Az{OEa@JbYkr)iy!!&@Q8r51EDq9!!O& z4KMHbC~E#QcQNxh%2X4R2+?TlVS{UZQ$D=3n|Hp&G1jp5V*+LoX9%ykW<8hfU+%k* zKk_3gR3)p^isAXob4%lF8HOE*{UY9km(9V=T3xpx{Cj%xl^K77 zI#g3nJ>@)F8bHd_Sj^CG4FACpdE-Whciq{)zkI<;;hmCGWN+A8Yyk8p*BdXuJO~w2 zQ>V0L*@&>~rW+kO8N(2XaLRddOa`w%{Be&{DX^MdlHTwOi>7gof?0104gq`xufyv& zdK{%vEMPGbG0Gg8Hw>l|)a8e$2zz-~T6a4cCy6E#W_n^Ogq?C3VttnF+dIH3Th#mT z`TF{j#7{p39MuEWl_mz65MPtJ6scpOK;_`E9DI{lma+Zqi&vgrr+3p*Z4AeR$=Xwa zz68p}i%wU*oQea>?L(fzjgpcQ^xU9M!OK1dKWX*A18Oy0yknT2ZwLH4Y`r=($?!w{ z??A+jvqA|rzcvfzI6c3*C+>jX^?&M=5pI2B#%?DiBxg(m!Y)^*%zu@m@1)9n&2PG# z@k@-Fc0aF%Xi9^cMxTQm@67nLdU5>z{P`&E9Qk1kkGx{eN>(=-mAJ?597s}^sD4Bjse@~8){`#8z487^i+`raq zarURV&NF@IZOi=4O>L0zxHF$b-_55z(-QRs0FP|-c+ zX#L^w{dqecmweA?sET^I+hN{ugGNC_d9E*iar5plJ_GIZ!P4zAJO7-g^f*mjYo`@4 z*J9BWa<+V1h1$l4`cISs=*^VQKTggTaerV?dkhFl7!iPx1R~V;xv4z!otzg zFs{9FY5TF{Qj@)f%bj{#ni4PDFPv(7(-Fa^i~q89WHh3o>Tc0JO~zd2l;jVd^DO>4 zHBxa@+Y;*hQ$=#y53;b!90>C|vX57T&h?K?L2T67#BJ*ibS?A_5A|OLduZLGRIkn~ zoV42SFUxz-*=FKzq(qS|rK0Mle^N|EdRjcY84r%DwY9Yygq`@^Ss2Cc*`P%e6Y}{z zMWun6*u9a9H{Dfv1Q`AE*gn{^gqi=2yIh*`Yqvw&8~fIg$-+qAlF$S3>>f3J=MVp_ z-o?$NG7|1l&QqawEMjbR&u!CBUdn>!s@`u){K~x&x`lorFguv#>TazYyrGOo!)(8v zi(pb5ySl*K%6XYtmrg_WvtCO$Gb_DMdyrt?6}n_a4#p>}`_GN)@tl!)ro}lXC@?1Y zV3kV8U#d|u^nScnZgGCE^K&^y>F#L0W7{Ytf66hhl}LW?r;2<#zNB>AS7^_=qs{<} zY|qu-$8wJGu@{N>Chb&@W^kfBtw~YTxb@;+5gbzlIzj0J^~NhCtr5>&s$}`(=l8I| z&%w8~$sWbs-XQhYH#?p_tpNXzd>13!R5&I`l6`Jo-YL%7#rB5{TH%WkXs!3&qictZ z?ZTIqP5aZjQ8y7;rg#0+Mgt#TkXhrx7Midz{aX7RYy?Rq=x+(z#`{pn0Sei5;~Gey zQw-UVs2(>p)x^IK{yW3Nso+PlBVoFO<-U`cr(kt!gBzyvmX;g#?zQ$-Tk0mtOvQ@U7!R4J z4%!x;J@U}bYZ9f?L8wZgWq|C0q$}xWq)nvIUHm&!_xb|0;U0( z9V8?IS;j9KVUf&!2_?zi`to64aF*vEzA^Laag#ZA^r#9@w=8*^eGp`k zX#)WI(%hzSos)@E$k}6}3H|%`5{WgYFsS|pVJ^w(e~;k#QK`8himXL>)4LmOZ6(ap zI8VE8d?;pd2ay3lAeHZw4~Ng9vag4FnuIA5fEDhVp+7ckZ%j+vvYeB$I%e)d>=yU1I8 zYSkWD$(IoI!13P{&arew;iu_)Tc0m_wgN-Hv+gO0eQx5V^;;UhjRbx-%ZCs?D?lN% zpy@FQVQ&$JU!w*fu2rbxb&dp_NAR;%-UIodzw%JP(1xktIY2|3HupUR!v~0{#`;wO z)IzAGXjHS1&qgFtU%vSGDeNRp1q97(?0FF!-l76Y7AoW?A6DE0bxDihCPVC|P#9IS zKN%KIi~>gNJKaV{qxzPlr(%$0%j_8t&I`y8K;Y*I617AYd(AoPi$^i22sW@}<;w~r zQUo3DfA%xkVyBS4~M;=_jAxs zMT3qO@4BksbrOV!->82_IRkj1dq5B%yl=$hNDxHWMo+ac*9PNm_wLs*xI`O-VX?UzeqpdI9FvmhD>&#kerjG13R)>Sk+Gd^ zv;n5gPt$CjySk%3!bgeaXVoCY4`;8ydAHAK7weC?IEH$p;dyoccYC+HIl08UI@9D(lWhoh##o4r$vbB@U}K zhVX9Dh^=G}XqYY?Uf^}lzFk;Sy7$xU!TADDLu$(M$al9GtrxfO-IiI@rse&8`^9;> zhRc3t{2V7&+SVoLKjbZ*Ydp->pCQr^x0l|51T|+2YOzDdtS5_m7U>^4% zieusWOHv19V{B>Sj>xe7k`$cZpHC-v=%mP>!mVRPCi^(5zSFXER#@eGXKb1beau-f zvoGXhBHag9zB93!gC8Y-%JfYoCbKVXxut8;eQx^Ye5ej($7qn}W$KrIIHLaCzH(ne z*vYazxmf+ade*V_uQ}U;uM7oF_^F1{*EF_gCC!Q6e10oSt&O%Th4%?yD5w)%#-@ zeZ0QR+~IZ5m9k%Of6SOOnxFfMgFZUl?WuBG*vFe}(RH+(f}vN3l@DFKM7y7=en&_B z{z#`KN5PsSg2(ch9}GODQ24Tow(m>Xrqs*=w#&T#mC~%FhxmQ9`=gZ;4^U-4^t1$x zK>toCAac!4`P|IRymycneC#hY+15u@X?Vm)z!AjO3$T_Ezc(ozhRshngO9C$?Ja?r zEb>E+9ZS&4$T3SJv){PRd82$L35G~4>id+Vpa@1>bLl78!H7<$Czx;t!x|LtHR8!{ zKidYctw0ZkG-B1jz{XnL2_z-JtW*@h`wNL7WM~OFbAh*z$NO3rDrndV_*KC7Jyg|X9jHnRbbpyzf`>dKR;3KgRlg0 zd7x7Pi3vm@PA*a~;)J$^ycuW-cynT~L+l#y1d-bs7e8UcNy5U%u^bQ-s8W}{@NlC@ z;U;AXGTpRuLIJ5ETFF5!?E(S<0M&NbOWr4mCXg)wl!5~Zx#Nj9Z)W+1qS@i=k|K~; zj@ETGMjE`3-+1T0@zjx?^A_zy9(YEA!b8qY$aElbxUsnW2Z2_RsZk*DfzW*c58n^c z6*xLc6|P6f01_OmuAV&`823YE>4wqawOa*uf>L30nCr>t0=80eaBMhs7QQ`?nOlHN z5Ij+Iz0CbMZZ(6<1;jp(cFaoU1IY(QoYwqGAi5Cbd5}9?KUtDel@TodHUc5SIb={A zfRYiN+9eRkYF;=G84 zis4#L)R*%fZkY=nQYjLPcg^wwbPNq$NVN%`6?HS%2!#xWg27H?K5h{;g>356J3x2X z58)v3_j|sSyq?J?CbqeJv?R*8Br-D69&S=n=D})D)OJ3Q*Be9ah!s>_Isj;tL~&t9 zTOECre%DLh83Fvy_D5j;69XrcQBe3Q|kU+MI%#@3#nS+jEn3iP zuf1V2B#j&y2&$nG;j)idR6Kn7(t?!brsYallt>yysGBNyhJdCF?~xLjtl({wFLK0e6H2LCO<>1yG~=Ee_E=wClyo5EU065{(Dz zK%#HVN=ygUb^r>NGXz9Je9}0rv%A;Y#zr^*2(yu_K6p;txY(gJuUc9`UAOzRcP1;tujR;QdemU~?Yf{+RTo z*!N+@>dq+UaA2{e-)8@2L(()b$(B^&R!@#V35koSOQIymHBf5t5>c&t?2zb{Ght8fIj5;r)uC?-XR3SZsD&Yo|2 zjAjDb2WN0OWMUWiGBROE*dR4RybGiz;vy$bo<1d)W(gMrF{gs-B;ya~Fz%>z_Dnw# z&G)LXz50Y`Ly~zzoJZK?twg|uj8p6cU<39tTs2BZzWhAt=txG4kvkC-&fl*$g5Mz> zDI486fO@aP)mY-)JrmU7EstTCC+T|g$QLD}mB=_W{KVvLG8PhE&1tMKGVlzf4J^=P z=)=N59xf1~g(9)IAP_?kyQ-+Tqv>^A?|Vl$H^v{rgiaDa@T{ZRsxhn)0X=rGOxOQ#EPhq#XxaVDiQxz6uP9FFil4=U>co~SySqS1^ zxJcv4wc1PKL3QB)Q)iiKW5*nEG^HTb15R{xMpmVGytV!3KD`QPR-B6Us&kp~+N2`L zdDfNMxK1YDdCS{3vmNHbwa;vZuh7juG1jmf);~!}qevSTE8!DbbgweOe{V9IcCwz{ zJ`OAI>*fX z85i{r*{s@mfj@4V{<#!hD*u&2m6`Q(lX}E(Y$A?oJ4Crn|w(0lk&f_ z``VRf@p`k74c0eDU+$Yx6fMKARy z3e$BTe*GH9PuU$WH4E1An>Dfn8q%&lwTqb37rlP_=vHf`WX175O}a=nse5XNEA70e zIob>4ule`$ZZe&ERxEAOk>WeBW&X)+SGr`4=j%fue_d%oZQY=_9B22NlCI6&$MeUr zy0+^>ya$GV^r>fAJZWeR-~T4so^q0jY1FP+?ZoIuOTk3_l?e$yEy`8rZahum8kr z8v{F?rHr>(^r!J16&SqV8e_r7uTiHob?djF%#W(R>Rto4an3N+qhjNNkA>g9`k1#U zpw!P&3gyBJr4`<;XLgi7?GCe*AK+EU4e%JIpfx`FGjB%=1Jea|`p&FgR!LfgSvzIQ z&%W`0KEJO2`KPB^)-=l{WeXqulRqi@1s5fM9&K0s5$bw6q)~Pk*GM4SUM{BZ?38EE zDP29UR9Q&FHDuYIy z{>pZ}ry(`hnAEcwE+-5q7pb~Ixd_XntvOcay3RPUWk?ocN`gQij6AmK2*q_eIA@20PkfOE52bI zBxx~>DT{AW!fIxnO-f-ugUTr{^OiWBT zFx|ehQvfgVFulL}A8+y#jn{5+>wC*bCj@hb+-oXvjQ|f7(lExP5Tl$u6hae0sPc&h zAEl=InD^r8!k-?W3OzH0M(v=^gLYU2X-#QPF-HPUhHq zGk*Ym%10D9gAosKHWHNumcOeQjeg>MJ@ff=H6mg94EL?G_ca|KIRO z1T`W3b_RnWQBlXCgiFlGC^K#?sy>WntpGV9R3pm2GH_D63`n zCZEyn9`G_q?{|`dPI)r{@lneX7c%4_5m&C%B>#t&kVsd^Jw<|XQSiC_H1=mBu_Pc} za<8WhF63(8*hN;(sAZWM)=|OQIiS$qbxo!%{^2blNG-%oLb_21IojiPc7bn6<#r!_ zyn-%FM_)e|`Fwb9B*sw1@+Z-q;t#<0Z5=}q*)^LDnf2Zn5!&Z7)lwm9OZBq%XSMaKT=0~OP_W?u5VJd!!gX`f_fjz{Sn(*+uoTAH+axHkI|y*8fw{uwH`TSG3&7G>*J}-5H~d ziQ5xuQWpG5Xwt!^5UC9UEy;9OubpP2f-5Nyf|M9;kS#fWSs3>?z8RSifEJN_(w#dE z{QP9Fe?7-4H|QqRBSkyyuOVcwCxU(p**5w)0Ba@Z02%=-d-OeIVgdBewjZp)4Vaj) zl88}KXeLXBqDUnG!^Dyck$1qeh1`ABk|ZV?O0XjdLPsE@nnL0tKf=sWc%eWfDlfF| zc#H5#r@cIZPeKAnP>PZt2eFt;!yMKi(GU=l4m?54MI%J&knHLU_zry-W8A9a&ow+U zQi)obZ2lK6Rkc^KY6!oD&?E$+A!&ob>UKG*9c*^;J}$wGiZjpc@)-zMPOZx5CfuJ` z!iu2QV+@8SOP)K9osx0o5nMSWUo){&!`PS!dWg$09$Pn3V+&*NbK$$=P(sHwu3U;O zOU_oD5`@44p$NXEt)n9mgMmy;Zn;dt&~7z60vd^U+n{jYOT|kP2chgq>%IyS6eIv* zL~UgL44`7gOw{0;2?|8x&pFq_D>k5F8|d6(hFL>7lDA9g81a za3aeJnl{8hy02cv)5Wf8BYYdM0Rx1myKGx*{E7U^ui(l&C@bDJ?08UqEaM64Z8|b1 z4wS<~q^*?6&QT&b0RO9guQzi0@_E{_5zkk4#fb^#H>)e^3!KF9?h6;p|$%!S@zlP;*zYsFbhQE z_zmFZG{HubmI&{18hUOwsdX$C;6qdBk=}#H;)Nr>Y7|5bPSsW*CPxYkGpfJJL4$_8 zeUSSP6uDiexOHuOmB^ho^k0Zb)pc zd2Q)5#rR){#jN^%L#Uny?#|%jTU7p+FqnX=rbgiaj4L${^TfMA$t)eBd7hFlTlj-y z%n%p?T_agi7WY-GbHdAI)4B;>h=chnX@CK35jat$;x2ejl3_reFm!}WnoJMvjJ&@ms;ojR?E&y2mGb8j2Oos6kMhFbwlr0F#G<1e#aL>852dq z;peTMoKv}*`$sdB%d%fwa-}aPMsZcXQ~3)IJ00s0SGPv4$LSuAUp zBQ8L1kr2z_8Av6;^i%inqn(t3=Z+rLrBM%KIKFRf?2fKmy3s`qr}9_XC%#uGyij7@ z$uUV8uaYtpNYS=6sU;BqFqg|YNbg_m!d0hj7H783AMdTW7ov78z~afzs=^H&UqT!H zqdcokJ!c`1ZT`fZ?eTJAbDpxsMLxydanCc&RnyVq>rs>*3*$rUO4PRT53y2v2T6V4Qq0r4s|e)(s#&8iK>!|A%fT{L%cJu`D(|#$F}9uE4A-Z@CkKwMVqE`S&-Q=DX!7k&YivC z{yR=d%@3e!9!W8Mj-xmke(CDQ{tZ!sPb$`*N1RQ>;Az|WCfelV^t|Ddd{eI%JA(Bz zBHhQE&(!D-Iqj%ut~q6TER2(;Dx^;Gac7EuRfM*g|w94265?KdXE2&YDx^-`7{!qeiu(a@AkjspKxv1GmS;EOV=8|8|OA9*sJ~Wp^ z7KuSKzq(HRQ^ReG$=G+sfY0L~14M-Iwu6D+MTy%@A;VBnp4@v#`$VbTPU<*RJ@t#l zB?wyv;t#RQQ)KG`%l1w5^ipAY%MCRH&`6;}hq%i}k6?lU#x1+z{xe``uPlvJLT)%A zg8@wdU`O^xB~AmBV%p#;Ko#4g@-V2h?A=~`mZ(mN!=%i*4x^4Tg$>7$>Kz0*D{xz6 zO3$He9>03;-aW&Kkc}~!+qZ8gx?Bh=YXXX#XHkB(F9i^AbIGUhhYqeL4(QC#c&`Z++m z8tVVLW}o1mKeA})Z39V-k$Vc+=Krn;m zev8zcX%x2Gb7y#?`5{5ExwupVSWZ{bWCUp?Bqj{PUR*j>+|Qvdeg^d*j zYnqJah1}^^++sQ=lir%EhqNE|Z!**at1-27eNj5&&;<;@T>DJCLPWhuf}Vj3Rhc#x z+86$ig@70Y2Hvg{%GjlW4mSp4RwB0o9f0VODtv9!h(uQu(6sZvuDq2C_jh(y@;glc z*&+@|l#ZZQaAA|AOVSfy0t-oZ#F*RB5jG^}6c%QoEg($;+A5R+5Q2e$BlQq;sRc@D zqR|A=M}(X2yJt)YFd7nctSK#xB_RQj9h2e$3G(B&EsGVx@KDHbmosLYRdo}f;o>2N zO5sGh^3j;5CmjHXLfD=&xFCx0pjlZV$hHFr>OnR#RQ4sWkv8W|B>RN90k$u5cR~ya znPiNNvT`Pxv?{D9pwlvoLLp6`8e=f`5ljX8A!q2^FQFVD^IcGEV3?T-$ix&nxG?0m6|krg90)0pRb!P8kBda}Q0SH{I{iGuxLTN z=oZSDJW?~^5F^Er5OrI7yOjHosHdED$x_L^doBPmP0Y+{E$$5S{zETUSX%lnbknZ4 z*hOAUA0ILnem1HZ=ya08BrAEGqNt$Bcx3FhC|QpYnX#q@8wuteV2IW!sF|nf^}VH<;mw%3r{xPY3Gtigky0OZfWgouihI?Z zjhSTgAbCR_U4abODB;k+%wSkn2tW#)LW4hM#kLAPhIpJX*z8}@(TS6JR|F!*^#cIh z!=pjG3tuDwGmda%5gdo+8R;?5*udVgq`YZE4yJ&|Jbi*!D_cv|JZXfjRPx?PCFkiP z8`qptqD7{CgJzWih#NJ%mOPoCGQiaX{dTd}_lFfzL2BD?9N_vk&_rGfrXX%9Ov ziEwyK;S&?YWZY^%hJru5N1Rc%C3^Ol`j$oZj{ocy4_v4bcUy#F6RN$LYC$%7>pB_T01Pr!ZsP^ztD0w@Mx#3Ww_X zrW5oH70tHk6AmTX7nF0RJ=>l()L8h17!gat5p@_D<GuR*J*F zx48PbkGrStxRRi=v-NDZ(YDP?eZP6JECl_P6=tO!U1(Y)s;vw1SG0=HQZt5+NK!M7 z6fO*EeKD}yH}pfcc2A^#VD$fK0b-L{BpdpFjHmASe{`J*RL^Vs?i0~qs0^7^Qi+lz zgeH{+V=|>kJBmmu3aL=2L`vEuR5B$~NRnjE5M>)uDv=_I(4f@$Jp28>>zuRBI{&ra z|9aoOqu=-Yd!Bo^uIs+;&iU?*m!xJV3{f`s{%REU&o8sfLrpg<2`TXI7H4wy#OjH9 z&(+57k@nV(e7|=8`j!E;&yThKwe#iP-zV-#Jn$Y>=`%3=%D#g;omX5ky}bEJM`fu% ztuoJ>L1vvUtew@>JKjJg%k|cdg**GMFzV^MFSg-#LzsQw2RljSTDQDsM?GK0`&4wC zZZI>~-Q8XI>t4NjwU2#goK#R_n&iFKKQL`)lA!dSo~k_hs>Yg~A%F}x;jicrg)g2^ z?CNK}TxBp$VZY~KiGMW%&yTbnyKG_M?d=_3bT3JhCCSBGhfYvYNu1u@??vKKvlGu@ zuugtoQ(W61S7<#ebQ@B&;#E5yrldCb>B0yTOQ-)Xb=JlF^VzL_pBq~ya8uP|tdk>T zwi{ldg@3a=D_iVy9z6O@--4> zoG7?1#devMB_&!1pPq1jaB&wY>5h(RqJ(UZ9w)9{>z`3P9lHVxg$wgCzs&9K+I#fq z(P{G*o^&_r@xm=UGIDi+&Wqh|%sk%r@A@Kf;hZ_!a-Ee&kN$pU%cQX_Gb&xW-`L}@ zrFYjIw-X1wUNmLZsxK2(IG_!EfaVz}K$$5PmJl2w=3lwi5SAq7u#OiUOS-(FliTUY z<@DNg&}D%w+6W}I)@S{>K5+8s+b?Q<8OX>xc0KvSdeEj9dru}cx5`=i4AAQz2@nKV zhC0_{-JiJ~qoGTu4?Mp*E%9_M;FrSS!KF!q$g7~y472514_j|=vP0<=J;QAUt=WmP zH#)mHoPKKO;PA#(TQ=}G;Cml?)wIda8C?4t@#>u8(-K(*z==5ToGV%TTVN7m?=|o< zCXj<*n$N1|A2Crz!n>iT+cLl;vIO9BaG#pJn&Cn;vEg`AlmB!zht29^N33TePwZ=B zj(ZzN&%?n=LOh}n9Bozh0$TWiOnI#JNq3JMsdJN*Y`^r|clkR17gl>}Q`D7)39E}N zzhxOTyO{NRK7OeGW@$^0mf2(7OZ&WWIP#)yHSW7YrJDa# zZr!|$&I|Xi&2+S%o1oB(Ezjmtr+RFPe)$_4M}|Zl+}#eX?oqtI_VUhbMZZ$3CH^#?hC)AbdHT^km`0}Y#*4r8Q$40n$ z7~L?5_XybJb#`(@_}_p3t)5>~zIpApvCGpJZhDn+{J23usNHZ$1jKJ9bp)xR#MXZj zMp@E3x!?hNRQypfU#j`>ERU_x9!C z;f98oQlDO-bn1#dOUg07zRBn26ubw6TZY=`P2<<>#>K+Uh~EUI-{UGOQHM@qbt=bU z`pX3j0GgTIlA1et+>V3(g@?3$tg4#1uDzhd;2)z^Mu9U32Z`%yrE_Tgk`<+h6}^vUB32O*1>^(iGpmuxHHsk`E^;26fr; z`AX^%ht?2V%azM|jH)(?+c@||V6C}V!^a#s5I)-1A8|RYq1Dyy?2?e;u56ySl4*uu zd_$qv3$;wdTBlB?SF5|c;Ww&;wAV^^KX#~A{_S9WzrKIQ4BVm8T4s)w3-5g_(g0;R z2lQre2qfdNoLvr#*m?!ry@72d&2%%P#+L`)n3XV;8g&_RF^V!e?g!itRO47>%uaS> zOI^Hrub|Oeb5vSuj(*mCk}xq2C6v%ifTIT$RG;J7b2>nfQRa(((q19uk#)1p_ z4tIPH)2B+ef_gaIxGVK6C1mTsf8V!CtCE@I1xw45QBnQ#XWq6jFuSSJQ!`H9W7r4F zL1%u1I`@3mJ|N$v{cCxbL+`fXHlk1EyW{rS8N2;sx2HnF;0$W#eW${2i^25b$p? zQC4#p`2YF&*5~rF!QxN0+i~I0{sVJE$jp*g``>@v|MlOR*ehX^T!pHIy*QU%JVoh~ zZ~IuEpJ}QKzWvUblP2^u$Ecs6m4)C|RaI4|DPx$m(8@@_koO!kXmf5(zgtCoHV_vH zF^!U5)b>RCn%9Gi(*T&F)h;(ZcDX=h^lgp%T%P}1j#O^lL3EDfZ1HpBF*UsP$@j|d zg&Q|+v>I;^((d2s>D!OcU5sk9M=a{a!yf8^%m#|l^H(~o_cZnA7MXiSJ~}jM!^J00 zr_|lziE|kOm_57Tvm1Po(scL9ub&>~3%l!*shQJKwp&ioRLNp5td5}SR~f|b&(oHc zBvlBW9uq2tmPR>j6GxzZ+R6BbEOr>pU^3QHPj4*UGJF50wnjI-F?su@StnSpITmB4Txfdw zpT3>Uol|xOOKYXN&3$>%qbU9B1pF;8oHsDH=&(>ebcm+KLFKvTHtyEapC*qva9~+CF(Y%Ay;#2l~wmX8FPq}aXf z6zEylzM#H2Kr_~6sfLYp_vpYNhHJR}?d>ay>(t*K3$pOe5@r;l?CRQiN?g9!)DE~4 zk0C^Bwen6Jl#wGKvEE1rZJ{K|P~1N1s2xX-9*w{DjlW=k`HT}W4Gkz~-_Gn)-8%Zi zcYjfSGWViz#JZs#h28@6V4fK*Q@fZlvO+!*djyUiRTgSB(qH?sZOWvB)ED9NQvbva zN;Wq>J!y)a6~}k}vL~$a{y~{Bd(omzV6lqxCmW2Ioc@8S@e=@Fs7%{k6S~VrPKo|M#_U?l=zGxI515p_ly6m62ai7%x$nCxTPG}&gAb``Z1-P+_ z1^pVz!=Yr#sqOPTJVpR3EoTSe{Jak69)bGS%~2{n7Vo?50oKm3u@+FR^hJ(Wz1|&p zxJhvavYCAGX@0ly|AhdC9Wg^9RW4^_$Q^P@ec}E4=L)O@>PLm8TY0dbN0`z<8~%>F z_bkv}!u*uI9mbwZ+Uk}5Svi((ASO>u&n?SC<==1UcaF8*fn5(CShcY*NU?8UtIeAy z^E6q@kJpyj;&Z;|en;iD?8Wshe}wH+OIiD2h5*m6eQ$X<@b`*dvUhVXNwlA`ei+yv zNKb6$!yWtap$R>YF4{K5!6&2mPmWpXBUe4ijvb2_R|Rz8x|x_rz#^c^4iAswyW!01 ztfU*o349l!|@K|Nz_E&@A$5lUM7Jh5|d@#r&I?t8+<%PE-Fi?jM9T0y|z-nr0 zLeRFI`@3rT@0{uUNoo?BhPQHT|*+wLZI%WQy6U)-4cb7M9)1EyB*K-UzzBJJwx#+JK@c3QDF zW4pSZcpf%QeW<49&YL%HdMY^>9^zVN6h}JhXlQD(TSn#Ze684ug_y(J#PK!F^#R+W z^5`XNk>Pj9_i#@~nO^GxgZ$nz)a(*qlcE@FV_{+v+;@qybNSr&)vbpfRhlKvlJC_k z#O_pUp47Jc?{h!-g4{|=ODpT@M)5oAB5&DDL)6XmWW|B6FYR4C^O~aqnQWlzgofgS zbJ`rz`T%CWTwpln?VlTB0ZXxy`U!ZBR|#mf>&eDaYpmR?4JQUo4qCA~z^H6&v4fS**MR)2VvSF43k~^Y z#9Rcs6^|z-%t(|ME98OV<$?>`TT5Gvk7e>9CJo+G?sH(8dVZ3A)mN)%g}|KlgKk&G zoZB86s?e`r7&DbF7mm(L-TSC=e%{Uj)(MTKhK+@#!($(MRvy-`${ChySa(EiDE@C~ z^3f5`nbQ;RI2yL^eud-P4jvoB@&vC;M>Cyv?IMlZbFXS@RpZ-#Py7AvD@uPG&cW7W8Ugd z6Y`q=EV}k_N7SVm(SgzapVoYuxndA=n_hQE*L~e(TUWR0yiF_P6%;u%sS@+Qj3#{>+QC2m?HZbn~d(-5Cqc2vruLzz~l^q-K*`w%2!jZkM-foK)MsJE9 zbj)7s)X9_1AOICb?A_ACwdhDm8I!K9RY7fTaZlmbrU6Tzt;x8S9(V3?cg1@uPte#^ zc#Xf~^U5&2FkP|7dVbNBn4f2U`n@euSIlAm^Pc(+hBYl7%_kJsSY6$6X8^C;ICJ{5 zqApRZ_J4JCb^9EUKk;degNIhLgdPf(uuD&I8wf5hcP=oxmYn>4)94*v&mYLoDcTTF z=9cbT)zK=biggiBRO24C4_}=-Y{-za&KA&Hh~DnHD0shXd1+l38$IdT{vvck5I)=#JHE8o!xOqE}04>kPju44O=bJ}GE1u7rzHFBj*6M7!&+PUN2 zy?aR{J3U7gzBX)`RoE}~gOy*srD1&GL0b>?<{Ca{)4<^Kr*1qwT(|W=`|GeSojZry z9l|K9rgPA+Dm^u|5WsbK5AqA^5_BA|D({;7X-!8-Nh>QWlf+rGuy)|!F)MpYvCG5uzhCU`!qegDDKP4o%*)pOS>D0(N;KJKM>;{r2tKgip*1_`}yU zp^5vF@EBykDJ*xS(UF-yPd<9^K;CDir>8)qPQLh@YB2T9b$-mn#YMV%_v2}4kxb1! zFD;!B7-$5r%b$BU(B|nnUth?xDZaj!{(hBW8wW+$)AG%pVZ(<`(%0X|1{oLT;5TMt zL+Sm!!U7^IW0OvY4vcw2_7jR%tX`c0dk3*f)v2^)ncjhG67Q&#;CL8(CYw{0*=zOY zV^RYY;3UB8Qf_W8e;xLg|5R=!@u$tR#b%3IgAb%Ije3HCZ2ky&l0Vx^RtoQdLlgaW zQ=49LNNZY@>A1US)mQI~Ew5@SY3Zk&do_n`hW`HJWka{J6+8d%3+K$5e$ z*U=w_blY}(cSyznbvsE((I@9pS&#UA>coi?el@(~hj~2Ky;nK_G;+^{j@wM9x zl_Lbq?YHMq6y>|ETNWOge*Xav3o1!Qt{r#`LE<+Z!RM%Hs zMt7REwHO+UwFzaz9_2Qa8kF&!F@*<^WHIChSIDGh61)&EVQEJ3aBb}rqy^>W%%R{lo*mj033j~`d3u)Wiw_D3t>0G~_l{D*T@ zt_Jb5Q(sWm0Tf&fXx39vQJJ4tx_ftb&J1He+%g4uc@qW{A;*(2k)@vFF9C(Cy?CE! zlj%4E&;vAuo=UK2eqRqsKmOCpd#s(aOXsA(L%aAEeh z1lD8X<*TGL6>UIbVj?TqAmgt3eu|d$o^dH)0z! z7@-L@Gi5K2v==WW%x&0|i5eOwY|CGHW<3ArAK&Rt0Bu(2LeF=EKq17;pAbpVCet;vlVgpFserj7=)h?>7xDojEvFFt!0u>W5+4=Cn7jeo1b?V{dPM1>s=j_N2o%05dK;$?m3vh9 zU-Uwd%7g%F9QYX2OBndR)&8bg^aHsz^b8;qF#w=^JB?w4^EQ2rJf+XAObeMIjX_hy z#8FR0MI7~6D-ezhgP6jxz&IDBkCd}Y_MAcsFv6EF=V3UX)g-|_Moutm*^%(_#po=# z@tFOL1uHL0DmOL}iW+`BGaOx2FO03Pezd!3WpnK7?5o5xZC;+VrHUcGv?)6W>n zT6cY(%#^#1xedqI{@CPRE z(&fvumn)L1)sT*O(gpM5KnIN`fnzkfGvZEfz}>^XCu-!+tXkE&O5K8vouj6sd7mOn{2w|RT8F?3*66M>{(n~%*SG0 zPM(64Ph;{!GBEUemsi79r&{WA#JLpGKj{W&pOr*%_6Fjgm+w+;7_%j{ z-7_dCNZ(KtV1&N?c;sGsSCWz{w634Z%shbVnxsAJlsz^0tc44$Y1>Jf^V9bD-rMxE z+LUzZoq322s6gGi@lUwz1utJ(PMI=AsSUrVPMtcT!Jg*oszz;i%WsW`N6*)f_YB)} zf6As!6UlUF=cU`uLi{ib>#l#$bF&1Z@5Sdd&=+5_IfH4>rAzh?v_7rDUwgqE>vPtD zP74Ye{+gE8N|yRmYXf*2Pa#r-WLcS?^rFZ8r`2?;S>q2GURQ`h;5>+qD>!r=kVs044i|d?fp4C?NqNXXwJ-Z;#gk3?XHW8va-Zw zhqqGgH%hat!Ogi+5O<)dkuy&8Y^dithnB5PbL${H*mv%{3Azar#y=&VX1+_y zk_qNLV0}7Y^fZb>%3v_VGxmh%aVlW`O|`AK0|HOJn-MJIAaW4V0?>s|AQ* z%=@0ZK7|M8?7iSC80D2W?TCtNUTURMPLmWzO>CIuR*T(W!B-&Xg##-;YP)P%%7fil z$^f#&h#?ycLzy6&rH}cfu;?c42~i)T^j0Jvh7guAu`24Sveu8|$nI-T;)LqYls0>3 zghEAl{J_M{{y1hsy(=xVEhEgGw{>9j33o4nY%!QJaN_!DnyWcjUS7HcCpV@jfgqiA z`?{}QT?v*gJ}M<8K2^7a{=5nj3Ykj|+MXE$A}w#@x^kP=q+dc)e9J2P zb9Pb6&nXQnAkM@PHXgDaPTRJqIR%tVaPiD2&U^g$13TT0%{*dc_Q_!EShi^cN%gcb zJ@*O;U!yMLH6(vx164pEPES}|iYcJ5$ViLS7k0B}=h}|S7Ha?>3Y{Rvk3uCU#Ot)9 z^m(KsP7JQ3J7AFgxl-wE8V+ z5F!D%RB=R@@kZOk8f&}pS^Vy$JcRJ#Aa$KiOm{_Q$aGLC<<#+c~#!Pq6kcC5WW$Zx0fQJa(`ayjouC{ft}Fskb)OmG?Kfob?i zmB*#pMb2eNE^=OK3QZfCRJsg=Gr^NlesO zroM2w?0~dzIoLGuAxA|=XU3r(RAsb|G?7BFd#hQ%2GE5xk3;i*NYeJDMM%D0mXfEiz-mW-CID$b+%MTb3f&7z<#{m*kaY1cWLE)>kud1FS z7&BJ+7eR^+N)a4xntC#sy80GF;WSI;Nq4yMOf&VOQDU@{)7kVT{VB8Ioq58fKV<-^ zc*W|Yf2+jP?v}fL{k;Ke@WR`!I=}beYI1%Y+QK>LR z>IEIA3CB}5MPgI@!GnM6E|nB~kh3nrOPXTAS;MdB{&lY_OV>jl2mX37U#u=#W3?tg z=I24&1mGM*ClCKR8Y2lBKB<{PbUf3jIRwFL(!A9F8Hg9bjuKU96lk;5y1t`RWc2SY zIuD#&uERXxJ-(fNs^WK15Ov;W+V@*d^OH55qJVxUU})CL{r!vr#$UK_fjq2m($N2#GJ{uP5~G3M2u9o; z#7McD>ay>18)ne@I;72$u#7iYU4e$QPtnG4(GYYrW@612GHIGN25V`>P^XF|Xvh&S zNS~w{g1l*}92LZu_ZQzLI$&y5?XBxkeh6EL!F=UAx}-k6BzFZ3ndgMV5;* zRw%46)0aO{ygBwJY6vz^3mf-IcngpYg&;}HCMX2N-FyNQyBB*9@xK3gs`t@)jaDVYTSp;p(1entCzf>xEPMXvZ5 z!7OsXxOz-?WLloR|1M@>Nvh3^>>*ZmWwb-oun;qJN|aeucw&~5-sQJU<({GlLyfVV zL5N!~V}7otg5z4S3PqP}4o4i4J(yur!u5jm<&o!4VOj{o(7nf`Z=wyNB9a@D35r1a z#au$|qgp~0m5o>hj#FV6=u1*r7s5HusU7Vk0n&Ao=>|cukr-E%-ALeen^*hcodbz> zfPF?-g#xT3&C*2j{_j*6Oes@qJuW+F2R?fLT^C*$7)`Xh;136;QOR~B)C$uKp$M_< zomE^*S@f?Ri&wb`XN6*{cWWCsc|GF{VZj#=z2XYQB^sPIZ^T9K`DYev$#ZoFtv*Rw6JLPO41N zzU`JBbDEgXjy3TNY<jPp<+Rg*HqsTF*;xj+aKA~vT|3Y7M zr0u7M5rW&C8l(F`f;pi^7T*?m@N2ZL5OfjmFxs3Ma+bD3X!%bZ-WU!DCw>7M8#)iO z~ zk0nZG)EY&DA6)xr_5|?K#gfY`=P1&`auOkFeyQGpkjd+i*^2>BheeA-VDiimK0;Vs zjqsYxwGcHJVjL*yo?(lDTF@WqVpuxBrKi=FLqYA0K^mh*!8od$e5EA9sB{gH=I+4& z22X|euv0)DQw?J7nBKYmv(T0x*^>}K#uIh! z0hHdr1ma(q?QegUnOEt}jIa?0&l)atWAZSoEUI!E1l4 z#9w8^2#Tyh0hk;+oj)9;Z|&VG)|BCodH4NyBu<@7_-ci-?ejoHE)M$Teomiekoy4fjRCDt?kv)Q~cK9IbpY-$D6tw<~WQQh}VhV*w zN*G{0+&3K5Qs!Ph5V`3tKygr+J`qxA8oAk+Km*5$g<5poh!UKc@R`JIV7ozn#M120Nl#u6k&D_(xOI=8a*x4XvvAPe)lVcm4hx92X7O|M-9K6Ovn7c_DpU_4vBs1DU9nzI$n*!?FLBCoT?qhg(fr_)gNfVph1If z<;9(zvR}vPAth4CGumkkU87;_caF>#!^8MR=_~`*nWq^cMtoL0Kb~l>MOy?i6Q;M^ z#8q}jNXU5xw78{vXr}vp#cm4iNHEoy*Qdf^(6nGWKR>^S5S1Or^ZoiP%0mgEc3^^+ z*BX`QDs1oeh1;YY-k2>6m%sxVZyPUM^1RBAJ=~gK&=#amXag3Z!&kkyLhpe4iWRx< zzf-}RijCbr&;&+&R$=@d;8w(IT2LoRWn7LVACz+OIoUUZGQ1#alO%7y^AfwHsT6n! z+S`JA(;d3R+6-%3q7#SJor#Z6`c4VIdFplzFL21h`(g|>O-ZuDi0BlqP!}{ApqK^o zPK^1*m&Ya8XETgS8yt|D2H7)A^aa|kf_h~&HFMYHHUWeC_T5?F-s|F2LqEIB_DJR7 z5s5(MV#kN;oe{A%KvUp3=usnfrQRRBY1)?<*lwl>Da3`24@s!>14*$1_WJcneDR?J zpV$1lB@?%GfNZFlYp~Vsh<$0#COo$|WiLh`nL_%uF`LJte-zSRUIANLf21z?6Hd=g z4@vzx64Ns#{F4gY(SjwVrFBC#O#O=sIOZU7%GWMmjs=gXudhclNW$nL#2**`a&Zyp z1#f`_$=28PQy#6`Y%~OEkk~_luWYz;0@h=z3}W`~H3i7wlnFBvASH%turgN1NC9L^ zOl6WY($dnX2}f&S@xZ~PHL*yXMeUBE)xXQyEt6-v&V_Ol)f7!1IU)6hRs%XE=IWUa z(%I2tNS|RDcFQ->sewP;j4hAb_v8)i3vC8gU01vc$@X-?VrdKA60Uk|X4&?-7&Y7A zjnyh5s?i<+vGF0q)uo*Sv(H{U>Ry-m1JCn`(itEXIErKjeLtlFjLmT0CTf`=FL^u3 zxx$)|{U7G?_NQ&WygE08Q;ARld8T^TmFTO{;o0sVnMk>I>B$P#P?40WsIEewus5UE`UTuA-YHu+;+g!z z!(+thXwS^bKH@)uTXRjZOwM)J55HpWb+-Vh+-d7fkvA6j9tS!O<(`m|sJF~hUqn;t z`#zz6y7N{NxKh*z93X6KEHMD&Xq>dy66a+7dINxI_WBAqn9t3rP+LN|$Xt{umI%n$ zfJGQB+{uy|s&87$CFv=Z#^C}CU{#K&OF-@5M#6Tz;O?Nd~(cp*RMkLJrIpv&8b5LBC@p zs;Mll)|}1qr*5^z%!R81CJg=uXw_?BC~I4K^@*_flIE&m(|Xoki+SM`99@#KkAWPu ze)SqJ4PV2{hZOiuii^0wmQ5x{rKA_Lr1DD)3`Cz%iy!kDxQURWXq`u096Btk3Ajk1 z_nGLC#0LxxWT}>2H}LIHl%hlPHGFHVpZ?pSMJ;k>iXHYY5!i_!lOSDUb8Wq7(Ewme zD6fB3{_0kHyKa3JV{;-y&_ea>)l2Pk{rmUtS#_c~*MYsrD&UGqk}6ak5YUb~P!NSB zCA+X?L$BWH?C!ynwLYX4t@;R{<;E-qm13d~sZY4ekPDwZd#2Oxv{r05AOOdGd3q)l zJ8S^vKr|4JI&sp^Y_*Rn(UpC)=h042W=kU#3KXCC0&pc|&VngUVNAKxRR95ohK3|g zCzYO4!u9!AYT3wa;gz|ibxrKLK)(=9AQt7j=U59fPb21W;u2x&BO z)fGgXELq>?kS3(AlGU-Rgs=evb_7u+)HbvqgD8IahKM@1>aOj*uF~OzIL5-Vk4#O; zCU@R|n|bx>c+h0AXD7sX9Q7mF+Xi53%ZR)b-Zn}0IT>#2wr#K(xAy;;PX(mcz%;3- zb&^&tDfC0NONXc&ljjdEVWG5_d!(Rm10OeRm=F!@%^N~cAxsk?K{Dn%IE)N825vm! zb0vf;&M0Egl-~DJeOGDeo0Z;_%%Dk7+%W`goxc#r03HZ=CV&y=rjISbO|(N)rf@)T zeP(N5=c4!T-wqA`G^j0IS`*CnJJ^IbW0>LLuc7pD!Q?C`E_`p-*qs9&@R-{8Fq7gD zHhKJ8>ISg}o!xTOwQ5;x+w9?u&S?de3t^;|M)Q;zo9v; zm7^rPu359jjTv;+@2jbJE5z=3;k%Oar#MF#6xrHFoE)#Esj18wSjq&T8VdvzCui-w zr?KwJ-h3NoSotQ&3#8v08FQPBw<6XfPJYY2RF-g|-*Ab-e#YL038F3(lU;n55vK=( ztz$EKzAXQU@l0YpfU@H?-!Sev<;fI4DY$j-|G?ZIjLH@Oh^m8ghnk1>^Q|`iQi& zmV{m2>CT6_!Zu%Wj3+M;8yyjIz!r(1XTP1*mpS$1?k|Z9y#rsmWb{MiNMaNgvb-*u zPSLvHU^S4!7M}B(6MFj##+?FL@QO7!k+1zZSki z>>neu2;M9qk?mbt%*$Wruk+YIfyHo{zWkuq^}~;RT=foAYq(ulP)$hyDGf#VjnN?= zhB~OEdGf!dRmJo=4}tP<1#5DNl*{NZQQwJ8U@ZcLxkyj%p&(^-67bL%F^*m1oAV*Tl=I6nd}#ODJ$+bB9}F)+=X2W{puplIP@ z1O&8|pcJ!zzhzF~pWqR)Q$Nosp|K;x03&{Ro23NY1<%0bz7YU(BAhsoe&Q_6YLw&v z{uB(%&~1IPoXv*Fowq|>qf^CY$>Ha?u~Z7I!)yA}=GpBRd004l(;sn?VbbeKcRx5( zj6~+sO^Y35UhFEeXS?s|4Lpeztsq#q1|=`BGBuMaKpd)$)~|j!J)P?SO0>2vPo0;h zIJplxJuYug)1(2`g{|aBDlJ!<7m5}XrvG9-j~=}j8%3}d5p6zQM%#Piwr$vf$m0wJ zs{o8Nh=pUs#Jz*yT_*&#UUknC&?koALIr_F8asux&GpNK&q+k|%wVWNfIjmwi>r#7 zf;WQ)6SNYnG4tSeFRgfvEKNwOm;CuWU-c?$iv30hw))fMU{Ls;53ks{9?V(LT`}s@ zD@c`$1(Cq=!`J&ax6e>xS{o z&?La2h-DL$wA3e4DE`FKC4eWFL2ci582g?Ii9S7&u3os-LMKc~oF#TAiRztV2a^9| zZkHjHlin1MVzmk1I}-K^K4H?CGhG=PU``@*O3}hnTfAk5V(o>wRm4*A1kmg^u;cGJL2V#6(K8|i=GBrl6m94Teldpdk^Xg zgTD&geqH${2zI*Mci{<)-x_GrL>SiIFNWKt-p#Rq>5(u-94996Ncv(&5Fjg>x}i84 z3QubUb#^RGhj+QN@h+c+Cl%9uFe%gu9QVElrrjg>nWt1>x90yYR0(-cG7=ZAZXcF} z(^_DT(#JGS;Fs-fW$jo5m|hwv_IYwuKm(a~o5?g8F$IBz0vab02_QjVqyRYPcsG4p zTCm*YcP;n&M)Su3Dx+xl#GXR2R)%j;z+*uA>OG3-heX_XER98?eWm87rd%9Ti7 zu0*`A{n3oTznnf$IOho#n}on{u@0qjUwwCZux_#S2?~R^C|DahgaFL8uLGtbpk%d= zSh)l(Cb(5$nppPds~?Dq!X-IMb&5T)?2*7RG$v0(5#yGkr?d=$-xinP=b^77kL2u8 zd-i3=n;%3NrV@J6)!?x$%w#g@q-ehdT>&X2JkN!7&0cr4qLFmVh$w_FAVV2j%j!;G zIKs~`1!qDtInmF}MSlH%wE$mxt)6t&ZwB3e@JSv=m^Fz51)@rf1>@xg)g!3a^D)My$sa6QC%AV4(y0^fkxRNuNe_#%dz4U zW+!kaVre2==sk`)u*Lx11urcYHX`Z8O3|@#z>dQoBH<8>A1@DutHz{BhX7^Y>ITN; zxze==y*DS2mtGa7aafsFf>D^>e4ki@GhVg^K^Ur+?;<+)>9hUG@~p46rYTv+0;QSP z5n=4!o*{c|Qd|gSfb|ac?%HH3>}W)qghzdzqMb-{qoN`(Zm~z0S3wz5LJ$-P3?E$2 zEq~zQ4IleaU1u`G-j9tv4r%)tJ`abHu|ttVj`AzpJPw6g#aqR74GXivPGcQ0(6XfbeA&h07GOP3h||q0OL6P81o%qhZTZbzz#td^P+@G zhxzBnW1h}}a)OK50GNQ*g4%5yTB)ipkmqpSIQwSk;S`j%N}E@XW!|_%Y43 zjHD#k8dz57pWz=)alILBNvd@4@;YChNMW!RWw2OpN#OZb<~=>du7H|HLO7)G9YQT6 z*{p5)_Acr7YH}W+6|}IZ;ek^G?I~U=O&|D}Kj)7H3*Eqa#L_VCqp&t-ViJuCYV$Ezudaf3Ca|SjR!&9V%w$>`+IUH;%DG1XFFn!U}ijq93w<|eYNF~TBlq>?O zO82Rn0sbLQaPQs%8W-LHLUGd*DY*!QP}58VQD0aEmvZDtY1_H?;SU!;?-7v@P6^*r z)lVbnBls7L-VxK1R?ma_6IHU+r1j$m{+qGy-0|X#2`64IDUnUcmUi_Y8!o(g#C|dO z3m5VRhF?&Il*hurKx|AQQv(DKoa7%xRt(=C0CsD9t^VUi>uY{q&GnT@1MU($p0L$yc9~yIFBnZ{5&o5FrO9FytLyMpWxQraI z*-A#Sk`kbT!nztyA(#P?&V=y*KME)>bL5BRC6j&MDul1K_<_p~f+7&T5?#H}2{@hY z7}bzh*Z2c=)EiEY>H-%x`Rl{~tfrvufXsOyq8f#~%)*0!Y~FMn;GI+Xg3;uaumnvM za!fD^cQFQTA|28`ilaa}C*BGlvQlhkksY4xWO9A+InB6hrPfEruQe+<^{sIh<9lQ{ z(Jm6~!yB{{$`X*k7p+pq{EdIISjuoPHwt!H1RGF(pglyuiO!U%FnkI&sPiv!X2r_3 zy=yb6Z7D{544ul`rdU5bCN|QEf&@+uZwe#%G8Qn1Rqza40PaCGhk zsKFP30tmHzs4D-039pzOJ-Z0#6et$$dL@ZP4D{WMz4#l@^*rqfJ#W7;%XjTO&Bo4} zYrmUq!cShjh?0N~N@?yunv^QUVK-wHvyq^~sy#sc6}D?vuLI>Er5Yd+c#n#e5_^~4 zAalb1S~Sn2Wf6)f*aXDj1<&U^oYB2|_m51~E~Vacy60{X&$>_%Xy_yWTOT2}$cMyd zUiSnjBjpo1zFV9lF_;QuyLGJ)NZOqe+Yso!W?lEAyCqQWAGy#B$_$_Fo?9{nbdSdL z;+mAU;{YWtG_)0i>>A1tWTc)nmI8V=4(+SIkr4aG+S~=x7B0X-i9y*wMgx^<7OP!pcEk%q{<1c% zjH4(-J4;$Cl#(=>L^Xs2ilirL#h@k>rSB>8YkjQ!;t>cnbm`ZfN>8E@@B)B!O7Lc{ z%C<-vmN+mwtaAa>ztk54{EBN2d$bEN9A_H<$Qe$Y54ZM_c7}v>m(i3o^e{q3WYkpN zF?Rr4ODeD(RP5foM$45C#FcP!F*Q6Jjhc1E%fNy4jXqZszT^G*R zqPl?t7bTr|sI^K@DTHiE5~3sz9uXiSlHhuvB}bKlPnw@=Vkh@#x+n_Ep)=IgfuO?5J|5en8h5`Q6| zp!pXyDSsv9rz@6hh?oktp5mrZ*<5_In2I5K{^AXxe1Fn`?H)&6_z_Iq`_ToOnA~i? zU$SAx)P*n;F$W$WQ!1RIXjk3AOgc(geq(P_gl#;WUr?Bu5bVH47c(7)8 zaHsRLLPJ9Il}$9Qmlx(8o5mR2@>1ja!|$ibh9tFQo%1+%_3`5$hpHx2l~kQA8-=~< zy|S_+U=h~cchDH;-E59Dzd1vCQgi&}DvQ|UlZB5SJ&I|O>HlRfLGw{V+sFJ!T0bvR zxn#F1KM6b!dEVQsNRr0M3uy^5ImH13ia)!ZK6PsEb<@L#4=a8a)(8ew)=-29j-uaB z>ow7T<3@gJ#M6U-U~}6n)uxC#3d};3Nc<5Xb>gTQg~=GV{7(HF)`6yy>eaQir!T%; zRMZc!SAx6!5bTRdxsO68#Jy^2ZZ2TWE|qQAwr$(g-q>sQqT2a>vttkW!GmXjS@NL= zqlt@tH3N|)Nyhif>%Bm^Lb+D*_AP3Om`RZ?IdV(fF{m-nO=3T7^CA`BPUZl5T7U!> z`(LHA7dMEly-z;2+_-T={K^}fH_7z&oqQoZJ(>By`n_+@T+PfB$PNOaDbxu9FaxjX zSWY#bNF%jt-~>fuV`G5BRaR?izkC605IB-(d|X=CZF>-cr$bOsP`{Dq5YdaytEI&U z&R%S^#+&e;5!~>ROKx0$5mNiAXO}Kr$jpKgqMi`F0qng5G1ijoBGejzfq_VGz(EDI zDk|`zq66p=z9EQVOKwO+gdB}bY=T5mVbA{8nlYQ0jzqJephptLyAc>2`7Gq=^$3ncXC|lJI5%tSad^mJ`W@ucyiF z=EM8<-y=D$D7!J4-0cb|adA;2jWNL40auuZO7iGOe*EHwHAY586x7>cCItmsk~xNt zOX)V%JE9x+5mtwZJapja9eePV3QG?rPy*Zk_zBJ4xhy8q&MwY=NIx+oKsX-eSgGa2=o zX1DHBVJPXK;=_kKWQts#Xv=ohW9@B?IkeaGZLhR8mqqjB0|yNnmh&LeJaLv+?VGXa zLS1jSym45|(@}E@$C2^l$BQCA^N7R&uO7_b6T4yDnh!{SphS&fHq0jHBgMo~(A2Jd zbp?`5-ecm2G})8UxYet*fhQ|~h+_;ZkYbeMDjyRXnK5+t?%gz|7x8icwg)CGCr{S< z^ytw81}gmAxpO0bcko98#K4PVrniOJs)+Ac)9n_m67x?R23;_QRNrcoQUT-$RBao0 zgVNzkZLM&~5)w@8k#QJwP#tXIE*b&DCbCs(^j9#TjA~n~aA@z#Y4FsN;-+uzv{PeTYQ`6lnmSm-;0K-=2tz zDy8Rfk^orpBP zd|jeS2M_Al?ZSQzlJJfNG<$Y5nVWGBOm3KA_yW0?fyGO-8A1ie$osmJ{$%Ne1<%MG z^f}QdX8Ts2vS&uNBYPR*zUPxHUq6o7=_lxV8c|`)(yv$)ibpm#^lNKtqcj(KWyBJF zF(DD$2=7jqux6}lkd=N7FCaXTY5&KOia%6I##$V8TaXi^!rF^v@KhH!RS_Jh7q^IF zf#fdyh`qh*%RZhP9%}SvGtbE*laUBI&G3+$`t-;w$U!_8uzB;Rg5(_^>}Z-|PfnjX>3b(q zp2TaIZqe=E2$%pSpPg8!6#g6WpkS58{fV}CR4DEhN;qcaM56NZ^OL!&ye}~#B4YZ- z#jOVLer)k%Ol#!#=DTKVy}Y7oeI-Y|!9lyVX>HM)5nG+8$#aIv$SXJ#VewzcKfZEa z)QDIA{PVz7Qh$}Fsq2{+E7>Hsb9$68poSqRRqsHz`0eLUHC+S%8iiNYTax z^X94NeF&NJh9oYMtfQ0Dxf;!#jak4TfKeUwFUiiEKmR%DKuCrKw1Hp^!f>Cd@=`Cp zA7Ci;SXX3r+_l&l!NDixj|I2{;8RT(Vm(dc<mpcv!vox_^(2K z-BFcI{xoVOB_(R>(}I5hJk%fEzWvI`4iWDvDz;{1WUT6!ad@1HaqG8^{AZivj&@sx z4jG(zc2=y~G6o2zIAf+_tW2~uK^|c77G*KhU&p?jl#lgH+GsR-G@BAeC9LZ(N4@;* z(cE{Xv=XJIr3SnAT;3U^)H3JHj(e(YPhPy3+u7Ne)qa_s>9B9pa=^>F&^2k_*aLD# zVWZaR(9L$Z@y`9bj$1p5C>fTa^>KFMEUVC#oM{^uJfg8H$_dK3bJ(DNB)qdwdjFD8 zuBs#ri#_u|L1Jk=6DvK$G+jb1l1Y3aHB;}JuQOe@?sB&de*sFIIpc1P7KqTbD-o5umM{kXZ=}8JDa6+Jt z>X{~aFqK0mr0mAVIRpAXKxD{E1GiWHzHW^ry(METxJ~tS*ZtrI*`2KhOI* zv%@cq@~xnyOWTLoD)k>IV=6K7nkD<>@(}AS|I13wUx7es{ndyXb;6#`+j7-SMmejC zmF6V6GOhSZ?2g6g3VxoESj?c$f?i*KXQzUwL21st8FR!JT?yRyh;le{hGsp7J+K|3uVro^#>2M!RNVgkHtsW*_LTnH%;fX;kN3KC>+l6G{^iH6%}c_~ z5~C(fS54HBEaFJ(6@9F!iI{PsV}HrdfKW-oQ=G{VI>$LqPEQ#9k<|YZk2LAPv$(Fx31{yE_?c1?P?F7N?ukPE z)62BtZrk3TF@riS!id#0SI-8w&H;isVghNU- zgSQ7r_h`t!0|J+S`n2KpSY(}IK?6PfKGT?Zg@O%%*!@$j3PUZ1eEQ5^P#l@0`P6|E z%*igaG>Ct~^p#oX;?UjlA&nCBub_#KHmvpY`*(*eg@&J62bVQ#SjD0PRjbP6l#rj( z=U8vnSf#FT_3HZC{#?uY9}w=$PH+psIFIL3t?^+QiDON!e5YOnQoje((TX?IC&6_MUe(+;ZTgd<&9UaQQ z%q}#t&>?TvJq>2=OwfqXg|KMv;c)r){W>YSMM_^;&jy1&tC{r_EdJAh1>L$@wJ+#Y zVG(!ua$76K!mSeq5YEoZg6>5v#LVyB*x0)$0VUK<#A}}Ewf+mu zbg=QSoWN)bmk}L^;Kea;!4K-JD4?W!2*cexpWi@T0p`ned%(!vmWNPb?e^|GlV)#Js{OfZ| za*AJ8QCp+OI?!pLu5)mJI^)aDQizy#b;v#P`wZ0#<%aDbD41N(kMWxt&PH_w-T<+kY*k7{7B|T%gcWACxYFe1&jH=lAixc^vO-)VZlnV}N_ww>SUF%6S zQyXyo*8+>nNw3v{)h3`O?xmzu+_#PUY4GAq#t51ukbj{%5qsL4osV`sG?9=0^{eNP z8zH&Fx{l4=qI!p~x$~6XZL78@b~s7?u4huINHZwnuIi#!p#{e=Px|)lyv*;#brTLH zB!FWox}@?+1n-N=a15+K(MNw!7qR|?f3Pg7fFh%GzO%xjvN@W%?|v@oQqVs_SMA`K zuV21=nU(s2oNnbWBr}g}CU5y8IspFruzvQm*!)|hLJ;V`~H4-bWV+vOAGI$bz8 zwk^dUQb4OnrJ!V?Fauh<_Cmno(TkS`M6{2Rj0YT8F*o?m+qc;-y!uC4RDWm8?3;iZ zP%p;CDbbRhv_Gx>w7`E7LllZ~uYUJCxUEic(*kBkr)@Ct9w;-ValJ*Su7cs@TFKyb z^Zfaod24rm*u5?y`{bPtA4+>a#h>mfuzJq~tuH+kZYI!*k*K3{(ie=Z|5H?4EcOrb z#(~pU#J2U$l}fQAY8}HE0-~XS6=*74+noZ8j}2aw%9AMH@5v* zxxZXTsOKxFm@eiS_Oa+mU{`FsJ^?EZ(E$}+ZCetNkdW}th?mw2i7y8nd*90_T6K@2 z3@J19lGsEHw6J&3LGIGNycnk57H!Y+7>Q6j_Pyf!UIBexzTa*f6UW4rdi7nuk$ob* z+#~qi5=r|_XvZV*@g6@0OMLQ=sM#6f%GvWg`G{F(>OwvLz;x9wkAZ8Si6KN))xTj} z$*%+}#@Z?`Y62}K*L+pm^l$Z8gghu$*B}2?bwWV08^g-@=b`FacH%E! z4d6tf;a+cy3H1D1=^XbXYYA;?GKM4&3#**GCj}cZSZ3=UY*0G zg?E~D{G&o5)$dhRS6d;fiAhIY{N{C6U@u)d^D7gKLl^gwe?|Lw@lgeRp`y!0-lVw6 zuwrK&GjCBUh`t{UbuJ4}%FC5Ag=KO+Na?|*`>Op^!b=?Ljh`15UPYP~5g19(7vh)f zjj>?l_g4>0(YPHX9nGngf zo+Vc1Td$^6Tz^?|4waP{8l|Em!+o!>k3R3fnSK@N2b+o;$}t^CvN&JdRAml;W-W z9gHJtYqek9LBIx1wWjBgn}l0+`^4PCy<3_$s+LCh?~r_Rma_bk$E)4+Qi$O!sb68+ z#Z&Gv`fb%m_S3J80RaJcy*Hv&zdUKzXTCB-QtiQ)0_iz19} z(gg!}$2i`-rH)ub6h{n93RU4>W5~UF7qrS>o_S~XWO@3?C|~;bmh%(JhEtalAm-iQ z);=mM&ow%aiH0shz6$(Q{NaqkXsa#vSTzM;v@j=dG6*kXD#@u+rwT_NR3@Tj75wVwU5bocL`UFQ=!+INHA~JU)GZC>$9vr1%^+(ziWF_3PKK7iVd~ z)4?RCBO5rHxUFmdfK>Ql8yy9A{~qX|DWz~X>=~qqsE-o%9mj1#ow;KM5m0D{>Kz(r z=@Cy}fQ`p7LF4EK3r%xPq$Q6RJw&~)&Jhwzcdz1`%SB2-pvDXa`ZGvBnEZ?dVap>2 z1q{8g-fB(dERWhosV)ydk1`33oX06F#E6UAG2;aE!E0ad1$$`&M+tf+YW*s=4giu8n&c!*3-5e!K! zFjn2M-YH;5nQ#H2FkQCQde=<6(-to-{cxD7izpQp7pD=K`_a79GY3U_RAylF(V7f9 z_dqyi;QO$JT1mjLP~isV$6W$wxn?_U*pT^T>Dib)-*0a3b;XjX8qniq zUGhmi`Qfy3rxsxId+(B?+!1B03K89u!=C7!=Z|3KVIWa!;hPQ6Lbge56`k9?M z_q@a?;5sZ5fj=|yLa}O+HlV1tp3!So_Yt=-W~k4d2>L3@bGZUKeQC}hX6TQ zl=O_JG)@4#)AIj#tS&e^|$^+I~dYqWb=) zlht0!m+ZM2ku%^}aznB8rcJR{nI~7ySwE^*7h0%u_cLCT?xN1i{s?W43u^hD&!0Yi z6lgaYV&~+0bW!N=`_qMM5nuf#P934%dw9HK-I>xw{1pH!5qo_y_Z~GjJAhLjb;o0; zWZqy-k&InUE1c7X0k2d0=Gt&Uz(;J1UDX0;-hzJeru(9u<{Y-{$;$^?g;hGksrzxa}wIXOXV7kb@I-~eJJDzJG$ zu{SA}{XWawmte<3GJ|{8-ZM|x?G$yrectVA_Qj;dVwVbgIVwJJy8M?fI~(*U6WJ!u zx8sic4TlyDo2A^GY;iMj^*azvELXUc-fg?R%=ebVnxe1H&fV{piEaTAD{)Xq5k4Ym z)2G#Dw+rLgi4z-oTJ>%2*_sl&*VMuyfIdxIyESRWtD!A#iYDjmX59BYH9{Nv`9B@L z8=At8vZVnn!ksRX3ugOF*U~C|dDls65LekJvRhB*-r&zd?k--kgz%F_r>+OhDz0yx zQ9t1AgRJ%|Q!Y+v{>uA5I#}l%2-}Oz1Q)1+Orp4x9kOL__gb&8;MfuA%~{qW{EIEA zF?|+3mtm`;N3IltM@;NP>5eF8VY?Qfs03uXG8L8#q2CW=vOp$d+HGyFKHpwHcE;Y{ z@eA*3j=LP+XQ%a@#(M`2bVT@}qN;j;-7oCO5+W@d%?9)jPEqo^>|IuHzMX0iF@@gCtPpJ`oHiO9`3=h_GUSCyJ6<)A~{EML&D2wc~4o~~L z!`gVw_hUPJram!4*1@@KjSvG;jl=m0)qw2bLWU_0cvt$nB%yJvf5v9JA5cM~wRRm8 zCDtQf?CvvW%y32ICM1H}Tka3F*^LI^?C~P-ZLT??Zbhk`$JlKUGm`DXNbM83J0O<} z+Epri9}TxL_MzuFfkU5Zz2CpQZhe3;_KD;`gLx*4Akd2@@rE=~WXy(b}9A;FB9 zTuL0*9LxrfB6~2=3s#4RMLUhbKR3o)JNV5NUnkAP()8vV7foi|gv$;K)2+94w^_VY zm=Z>fy0*e6K>6yu(@!kGJVzqNP&TOKwF#`6fi`=40%6i1aL2l+rD;9@wP9(0?K zHEFxlr|0ESa410lp%kscFB4z>#%3&6D$rcQU4`gOhjTqWvsyq>CXvK|<3&v9V<<`o zrCxv7>{iwm1f8#5O|21z;DCIiQy87%p43h~bW?t)+_HwUsfrT|AH8*SaA;ldj5ypC zfgYEY2kpqe(w#HnrHs#^&xb(_5rz*_Sh8o&3_wImpq->k40zN@rlL}N*drrF*9T#@ zi;8~jqQ>d{s<7*o(g5GbVN(ieiutpdO>4)GznOY@^nz(`UjKdu<78rDa#($o9;(6l zRm4J2*7Oz0GqAN7HU$An0lPvv)9hsGXxqyK)MwvS{@vT*E8_i*J#){;TU!V=jPR2+ z=0=}Bd{8TYyIp_m8R^nct`N$^Y^pL>UDZD>GIMa0gi1J?x!`LXKj{U36~X`NrRSWjwPdCnG>iHXAlGp{^f z;On%;4U35|xWIpIZ}`=z*HWg7(+JuW{O;8|TJq)mZ`=3Mq)FVa`-Rw1oY)9_Cp&*x zC2)TlL?Xlx7U8*vh*z69G4_djr_P-dtJ_wY{zy=xdjtb`rYDYfF&&eu0xd9g$hc&W zbnU){I_OAf7gOq!?^>)$&Q*ynaDN^BGtBI-@rAA2uUdm&|A4pgzj$#7k|pFDvYm$~ zT@r;5drMHVmkT)F&TBax5lqmPbERfCGWJD!$-Ev~dh@Dvue6y8ZSn={hvpTo?VI#dJRadXWM}mCyklK(-Jeffk#vUWG?gNR|l|ititY z(L_Vz)zc<4Jx{>Z0q|_&0LvVpuvflQdVTsmRS}CryOFU`~=o>Q@R zVV2pA91~avKCqbZCiqv+L!Wc`7?Us-X}Q$@3lZD`+DOs`P0suj7qm49EtzZ8v9N4F z){1g7zmXNtd@}U=YIGo2JVmbs_%PkQX{^X1`Z9dk$X+jJ5qik&J!nRIW4^jY*|`zV z%+g&&@$qJi+5Qm- z6o&uy94=c#Cx%XZ$0*qi>S4S7!X!`XQgsB#ec69iN8KR27xK=@ckUdCS0|`8d7k5g+N-r|DbNW zO$0`8*bRWBVB`I@4YF1zp6-i8o-+0mDZPRv-jzCa1kgXI%<(g4UVjd7$@@6I$&MXI zPZ*Rml1b@2+2>V*aK@p)sX}q{C@pQy@G+3CsPoTu&)Kn2A@@qwVUKmGrCOXu?}4>M zlqwDf#EeuSm*~$CDu``aelI=l@TGP@BRBML@Cr75g{h@)OO%4=5zL(iGH-mbR;Rh{ z$WK?Zq8AWw9OEd}M=@O%*s`1xdwavV*LQ-Sy!5EJCE{Fq3>l{)OVIT%q=KL;zwO@6 zc=UA0|8l${9z3|fF!}}GpGb%p6cHqnrP$r0c9dBiPU4+=(pz=anDk9-V9!w zx@=c?#TZI!v~2yDLxBxM80H{E%IoXJDcM3~&N;OUoSjnF$(zT81hJ9m765iI-KFB@ zs(h|m&D>wRBD(bRe7E=6)4z6u40{E>|N8Y9ID>Emr2a7*{EWmb7|zfY=~to^{Pk7e zski>D%>b<2)XCh|_L~3D#K8_^%W;{I8d&l1qo9sO76kj65t>t{--5@Zt4(|SSV-iw zZCOmDnhJbOF>W>BCUa@!AwGP3Is-*)M(Z+d`t;1Ngup%dG*fc$mH7BxAhJes`T6<7 z?w`h@L-y2ndlBTbCE@p#SdDqNz9484nGU?~AMCmJU%vdyr+a&Mrz*Kok^CMcJx2eW z1;?TbT0Hu+7#V&tysK{)MsC1!(b$SmI_^N*^n_PlZYi@J;X>F~6H%Iq)LS8iLuA0y zV*|%`|01*>TxqZ?oRuRcM?dPo_JkIDR38Zwmo5< z5vm^g%$)`kU#p*ycGZ}psczQZwQE_nM#-w%;aXYSpb zf&iD=R2Wov<>-e6vg) zVa);8Oh-V}y)-3ZmCFm?X23eTj&q%p(HqUBzD64}oGjFpoBQtxN7TKX_c<<`G?a#A+oK6Xy>a?(ARdmw=+AdoD-1G42~Mxn03S8~1^~b7;U44l z9Hr?VUwjCCSl_pDh=$GLz;>%0JI03&liCn<>{yR22}*#ou>aj115#3S5a_>@0H_)O z_Q6?1!KM#_MYPiKy4u&x@G9X{MAS$x?>TIKZ<4as!onD5OSry&7y-PBeM*mH5o}?C zid1BF5YqtjWMr5ivZZKQU_GZ(;aBHJ_TZC0DJcmdOPmx@By=JaijMJc=is2A3~b5D zvNEZR-l)-=I&;L0zjxw~aSvo>XAj5a&euXN^a!T*vW4;d`M1k+Q2K?EKh6cN_PBzI#;rn!vWK7cWGb6^a>=GYb6q0wMnP>OEygVSZSD z86xMC&A_u1+4=SL##r7d9$Acw9!(=$)$4uY_ok*#_{C3dMA?3__q;44!m{@`)7e0L zoqPDQBmi6pRRUKm!$ps<*61_DI6oRnIBhjxDPuqA#@(%dWq-#6JynOj<+!4ha zg@-(z2wx+_4uy*lDOlXx3^^PcG(Jk5uTS<+MkwD+(i;GuYeOA3D)y1 zMt%DR)jEIB96<@r8l1>w0swHi2kMF}ER;SFjtFYRv=`LU$hf_sM;Cg#y1PSuJO;D0 zYibqEGp!n^8y}(MxJL8t*39e}&mlEJ`1(Ji^BUbN!h?r>dwA=LhcT37LW|BVH7q`a z1BF~P(lba22XHd@@bzY)ihW8$SGilXu!20|Y=b!bi2Tmwl>IOLrqBdmkJlmhgIZh~ zOsnnP-Q5{@NNn~XT(4>>DswJuq!+ub_1oERo>bLPjc@nC|CX7V9Z5`_%87~qUn9>P z85N!&3_jyinm=?5-I`0ZQ3e zP8E>o)aQ3g%PpKTVni!{7aPa%)A_I*?8N7BW;Blo2t#v|oa<0^`gv~?J~(D;GbxRv z>&Jqy-T&wQB9EARu_ZQNU%!6+5@jeu$hO*UH2G&jdBCrWAgsl>(1cWr_PBI!v-$sx zQ=;8^k9$MdHU!v<_yfNBAcHAt_i#_$$1u3mrw}zAH1{N4yWKYobw`KH; z7pNW3oA}SjKfQB=Nsg@>To4)NpabZ@DF{664+_CcMkXlxQEtDkgY2v#w3OJKzmXKi zkMd%*A%I^rt;?W+15tTk_w2)f9NEVmJ9gYu9EG$F=tK$o6WdOdgL!!kSM1apO8@-4 zxj!_&hJk3aP$=Mfi=&z#zJ3+HNWfvo`r@-3UFDUPj|HXJgT8++j`z66B@}r=)=JfU z>O+oy;nOVN5c7^16}Y$lr8Q<=UMMB+x^?Slf7AJo+bUZ@#vSmxefgp^O}mw0V2CLr z#)0E(#MvsOf+#Bxbi@q!Z5-Xufg8mQd!!KYQSG-eCCi%oJMp7SqMBu@U)Zdz!3)|} zTw4w``=Z}EKq=vUIzM9PYhe>&J(7W>_^}Y()l@%1K>@}Ct#PJ^ z@Dzz(3z32AGeT=yFQikL`kY(OD}2I=2_CLX{W!|xN(sA}K)3>|IAiE03)0HG#I zP_d88QZQYKKs{ku1b}{jrlT*d)`7H#=uJdbygFeEN{il_g*|&84f`%2Bd$bLWuAw! zgGA~PRTreC5Q_4L(W61Jej>qVRNQBXUVa0y)hcCyGM5kk%wIe23hBWZ%SgB)=UNz0 zQ8)AK;!ptx-&5?DBze;f(G@M6fo=iM@f)!bvFJC1CJgjW?cYN!iz>J5qC((FpRfIN z8IcM~2=bo@qF5LB{F?SCl;f2cIz~{15b@5eiVP~b2!)+LKd%@N=C-8w8Z)!`*lNk& z(&RN$6QU*!?$=lH%W6`ln?iO0E#9#4 z*=+S3M9(aM|HN*qZ+kOrJoQ66r}CvxB{7MW5Tj?Yeh=dkxUoq{Fk;ULeG-#WRwfnO zGBib)@lXA}%E#wvJ@aTl>%8rH>%k5nzz2zZa@@_{pN2u%&)AdQ0Aqb1Z8j&r+lLbd zmk0O#;TWZVzhChCb)U&j5)ttD7hhzoYB|WfNBZh165TfZ2w<*kQmS;cXnL3d2O4ENet%!!?XqFE7U} z>4*DOm70h{eg2?^Hye*l-0;YIcr%$yoK*`p5_V0cp$Loomw?!FQY-X)O1ORLlwpm; z(mG2*N{u7cR@(8QRc%Zmqtqa!#=SwzJsUW33BnTR3_*S3( z#}}*m#~y#O2XI>yG5X=59H)vsr%ar90>$}i4du~)#+xVUvAJ+th+dVaXcR1lN{JXc zVJV}`LqmGGbo-2%4W?gyoEe>d&~qQ56C&6Q=tOYKH_oN>#`PdV?;j#v=IQe{LIEGlh-lK$DEz^`)h!KehuZT?}L-Dq}77RBO3&XTuHy7J=2! zI$*M(-0|@DBGMlGAKC1O@4vNGI??sLnYq>7qFQ^2|M;1GH+v{=6AiQ*LQ6M-@X?k9 zmBjFUKz|qt#C5Yxa{5MuJFLdlx2+Bre+31Yg<1CG?+-k%=orWCWBC;LBx4!ENeXc~ zC+5Np8$v&fmsdMHRC$D}%&{S}j6pp4nZo(ZBe_>5T16@+H?rpY;mh;>(9>D2?cS{P zeTuCq_u8#C7d{OC-lM@=>i=s2#PQF(BlsxM&@xwvE8{eEckuQ3^ubidScd_rg%RJa z2RcRisWg#rh{}HfAN=FKzQ5bp(TlRnRe_{<$rrUc7eq8Ahv zlU1umUNm}hT31gJLM9lD1yDGh@Ezfsy3(7kTN@Z0A|Bv*v%hfb1o_uy#7GSZ2W>>yG8G$r?lx zQpmAAKBZBS_Azipfn^92D&FTzHfz~!SvRH7mX^ze$`LT>vDKuBl!vsmsNBQ_e`oqu z&Sv%#&2dKp&R^g%qZU;oYltFEguZtjfBzwbo-UX@H&nW1EIH%)c6QN=oc69*ae8f3 z;kyB-N~mx|1_s}Fg!a<_fi%5Nd2SnV_7|lhRbYLb=Z9v}we+L)Jdn)1S0O z0JBv|j@RO^apt{%DhOWW!RC0LvS!`79n0F+tyr-se9ZQ$8+V#IoXfA_bOf~6@Xp(} z;uW?OL_&Z(v!e<`o(mw<_?vr^Ih=I)x^zuKHhgrD_R8O?l>89{5NVb%PI6L%R4yH) zIBuwT7_K_wcXjyf1$Sb7_i_x*L5)cP&Ikq(3Pg=<^J8JfwpmaDxgPiaVhjaO8Nb(q zwF1?-!J)^J^P?@hx|BHeQ|uj|6cLN44R)rT5*_Mrm?^i5Nn(LLh)5)QMuHqPqGS99@f ze)|0;GQBtxg(5%DaQ{JHOEEA`|JY!=o^yL{PL`*6<5Hrb0r0yQAR!{CYx~?4?i60z zVkUHLJU~YP)(Bd5KJF638Fr2$YxIKFHN{1MB`*e0QO|{CjX{$zGICHCJ-f33{`$4ZSB8G)~ zd&I5~Q^Im`j$xhihrY77(>lepI~}_auZ!(WKJ9vpyCQH8oAb-ixt&q{2s_cL3Xd;$ z#8?j@+$JT$in^T=N;DscC)ZDSBQ=8PHUOOoP)}CeVB)Ulwi986=a3;eb9?Bdb8#6$5Yd5Plr^lJMz#S8p_M20PZ<&RWXAOAB1VUDW>oTM z`UUX@Z+zClh(jDM@IRC!+-MJx%}J+gofkGw+z}iZBJKX?&j^IIR9k{ZVXKo#OikP$ zS*bLp58F56KsAANk(Fk!$XBK8PPm~M--kyoJg8hPgIE7_tw@)JOJ83(*7wG5Rz%IK zYdu9a%bQ!CWVkbRwF!NpcnDlm>FiSj_A(qbiZ5gu=a%2SBR%-T z?iK%_6B6O|LYva3&r{(<<=)hX?xGHFYfk;UiLzHy)-1)sgp9q>r@O#o!F|!ryBu9# zXa?@Z{XGzfl8UYp;oX7E{~Tsbb#ZUMhW)nXFJt-{!l@!Wb;icqPZZ~Y!mOW$Dx69O zz3(craX{uSA-3SQ6WKWIgNT6bMkCg&%Qr*N(gLH#s*z_>dW9Gz!#9gS9tHsn+6})T zSScQnkUr~&->YdYJ|Yut4};!q$jr6UEmF>3#nYrG-!$(FJ0=?hq~Y5)K{i561QTWc zgN${kZ`I@h(K|;4YbTrSj(OU99H*Z6!4xZkp%i%rq%iPE?j|M8SAAAwFQ2*yH!?*# zOPza1+QRK)@ADxcd|`F=$)coN?E2~U?gz)*Y zzP_{kcR@m-w>ix0y)SyPnz?$S90+mz{OU0J=j0@_y$HFxjxq5gHVI85cmT2wKW)kv)9nAQhp9oRqe=} zV=s{i5{&O7~vg2@|sEizGxKd=P zoH%}5IkQ5<(4(NsO`jNZ?b-&q(A!5)_&q<8p{T$^RXu zu;Sk>fLaMd{lCa%KJ)D^?zELcnsa{oW5g?B^NTVJtrU`77K_M;faIsGH9kKLnOX3j z;j>pZMswp4*}7a;j;)F-nJpApnADrYa2R(5Rw59Q$6}I+&hGeh$NGY(pfC0daNj?9 zGH+%pqDj%YlEm_f#Xfw>l-{@t!IM2epx|qWpqC-D)GV6jKbG#9weWG@*Zjz1#U>!F zV)6pZd3f)WL>vJ3iZm2)E>UTc34!14X6D=Xh#zo5J;eel=ImS!bOROQ5XZM}5W4s{ z*^6l2hqc*K7qet!L7^rY%@or&!oy=Z!NA=`E4TVhK2#&t;0U{YB=rpV{Q(EP8F;Ov`rX^jNmn>_GuoT#ri&M2R?fKf8nFNN^k@E zFmdAb{Jsf;9QZIyegykW;A3KY0wY11aB82K<5nBP$kBOZKL~>(1RYoBaRSjeS^>?u zFf(;E=X_CbEa77aU5B``i1+Mk;p&M!?srq2UuxuJT4gOznBB?Mj}9d*tN8qRKEVXUXDBmFn6oWd z^z#gNJq((h7ah^#j3d$`10!5IIUx??boMtZF!&MVfW)#pYh|zcauJ_@H{@RN;W4`q@ z4?-LAp4{0-$n!y3{viWE;+a?6)33#J_#lpA$d0IdAl0!ZbQvCK39A|R`66NJ#IgI zVAXv;>ZgE#(;xAJ{-O2-=2hdUr%dF&#Z!Wyd1rO53Wkv-qfe8vAmX_h2h!#Pk|XXi zrj>}MVd=i-ddq zd)r>L3;ao;9!0z$2Bd*NpDeuo;DLqi#lz#$p&54a&X>`pi)sK(bcdFzS${cF_J~!y zI>~Vqi}cL>D9E)+urW!cZCyPj^Jw#I$G_u7J#{Lb9(?@IOMpS3y~2hlmMkJX$}9#2 zpM2KMzsz&C9mtM|(YmoJRuaUSLqP;vU=Dk9dv*A0f$#P0J0h@wei?SHOu;ztspxRn z?4oC88-N?cl5@EFm57ECBb$O9CZh2WDlVVFyM5`BH`0q^&EV~>Md&zQA6aFD^EWy?iqr(wKx4X zpwW_R%aQyDcvjd&2(6>J#3cs44X0_%r6>TWhZj8=my-o(Atw0X9cYvLtDU6MliA4;egy zF&gl%(*|#)Zqh6)61}cKf50Y-8{u&5ZX8h*k~$XfpTC}e z;sj~em6XGlKq0ry=0>li`^|t@rQeVfq9ZWVml8C8RXmX!!Sg`V_q- z^qbW@(FR{0YQJmO7D*$u^v916P=kRS*C0=aR>n9Q_SU#76@^&Gn7T#sINR#fAnl>c zTuC~Z`KM(gGzhmbr7@DAv1=Z5dR>7`%qk)E>R_9(mgcgZqb^_eMy|EScwvH?Zijt5 z3a*me6n7WS7g6c}z>riR0%6U~&3it6W;IESj6Ebb=m>oZx^`kZLX9Gh1K_gk7-E5b z=%YK!;0j+*@#UketVPlRPky+djZn;6T`MzW$PjieoeA*Z{Bh_1W|X=c7_9#E^>Jb2 zS+&5L!1;=XHltqkyMDV@hfd39uQe^~-TK+2kNJ~FVacPDMyL2zER~wm;h==VWf`wi z-WrlBI%+C`M&kn)e*HSc?Q5%c;KD9yhOh2rm%ILpch`JY3Kf+Q&l)aVm=5=&a(X_g zS|aS8y;n{t-vabxoT2{@PVR$wHja~TAJ0KdaxOR+)NO{=OVTzuKLR)!H!>*;jR*5s zP&<(#?X5S|6OpK#{LqmZ>FK$>I9q3+&uZS{Hrb2i1M|%`Bz{hju`kqeV!9>8B;-!{ zr}HWSBKoWneS$t>WK1Fy-Ngg%+NFKmB`hLj_gJQRb%Jm5x_NUvJK(Cy!D*MCo!g>u z&N(moa#P&a$=Zr_yUC&vr-K-y;ItS&bZ&JO+j2epRlN0>i#c%IKAbUo_Y~D%VS;l0{kg2Qlp@K!n$Ae6V$hFNmOS6fjxuZ>V z>K7&3KFkefnBL23%d-m)AxB{J6zH5KomNX>>Yq-rs}^?e$M{01Vxf+mKzt*)guFhq zj9A>2Cr|E4n`bjMk}V)k##IDM((9t#UZ9$bc-@axiY@k~-32;QbPmeix+-WHPAZc$ z;ixK7hml`4_!+7AU5 z_Ic2Edx|J$eFNRc+-`C_ZGxWWTw5+itud$2#>Pej*>aQ!)_mC}kU>itLlL3NL><X+Ye%s?^4qBy?~QYg?~r#r|0T|e{U*Oi)HZ4RAI zm`+jhf46t&<{bcOz@+_V?VG{(5l`KI8eTl?wVoU|#8t{bs4c%>w|)DSP1}pkj6?c# zs@`qsg`|#;UeF>i)}iwCjga$ePL@2gj-yxZ=<^s^jO?}LN8kr9q5>at}B*UFpjpxs(Ornn)$cQfy0Z&up^!<@oKV;W@lwR0of24HL46jc1|8d zJRG2l9v$XKKOR2j`^e^YD&58l_q=Ynkkm?@-y?R_QQGUgFIjiWej-3)5}ic6{tM|& zoysZMydIa&W^}gj8o{)xJgD2j z=dma6FX?Yehp2)scu03vJ|%H>$!@f+2g!0{GN10y6S*UwuhHp|KC*yI2-!*QNtc+$ zV^$9Qcy=yuHrqucI8$7Ux-H#uA}vyAVeL+WQs5i@y*O(K5c!!;%Kc?8-w%(RWJ~>^ z*{QSHuVs+gpi==5a5&WmBcpkB)g2Q6SAmf4Z(mYK3oL*Xl&(ivym==b7?eXXEuxBr zieSmFLr*L{bE?Gn)N4riftc!1B=%TeoW{iHT)9M^0tjy%uV63Hd|7RbpXj^(hOYGG(I-0w-4e<~&! z^&rFg^`J9@_k5c(dVFw}HF;E>FmO<5oecg%COL_OSM8$gmRVn#^8EZa7t4-Iu{H|3 z9Uvwe3l<_%3qr47?}Djw9x|4KW0`KsXIHYX?o#1%$+Aoqc5D{HDu`wg!g{7qT={Sp zA}CM*r~8mMhFTd4Q( z6OsCCp7q?Z;^_lRdaP7hxp=V$)|m?2l;Phy#dZ!LV(Ph+_c+L8qHzEx!*RBon@CHC zPw9$BxD!Yw!{=euV0~;x93xLi>AkYz;Rz5W7#)FfJgvXd4-cXhe1tt-3y}rX?e3IEvauVX40-^>2`JYGs1jYU((=BH zB>X7JPNIOGtt`L3{Lw*RlSnP?Ol57BqQ(Ap5l)c%u$>)Baq=Fuz$<+&(O8fv?9zRy z@E_8jiAgge)?L2KL28-pAER|LU(AjwIaV#fgaLtyiT=&^Z|Yu;#Uu6b??W2VQTmJ}M=0mSeLpiWVzP~5OId3o!915Gm-#;u!2-KY<{ zrlW~EOy4aIVDO4p`;4vb-%)P*L3DRfYb&+cpa%!iY_k4nA!a|#7}fAm`%OTlXu=i) z5Y*jZS%XUqCcF`|z67b99FRZ%?LcZX12DvJX=gKoz%-la+EqxXPM^iAFqpeeXkEd8 zX*8G6V50~6@3f)n;(TVEUNdNg89uaXEQ1Tq4Tu$bEfx!Tmz|PJ0@#dOCL9A};{k}P z4x;OEVPRjYITCM>g+=Q5j*R_qyRIyCNcZmD1&saEjNqRNBxFyZ7yvH=P$w9wDJBZT z+C<*E(9g(t(0^|fudaLdCy;jv^70aRmj=xOuX+e_F%i0wtZlGHAt4v=gvaY=ofX|W z$1C|6+MedsF8Mmyb{E26TWYK3H~r>(uGvdHBwh&k*R@Lw|D=kXfB58>d^vuFkmZGi z4Mp+L0aEeR8*qOSQcm?{78$^N2EZAcEn7Z+8@OWSN;ZyruVUX&CjeZLpv+t`O4Av; z!;CJciC}Orz3Ve1qrLQD@o+qv#Qk$V`?bYLz$;`HKS!=58-N1D&2fvow8=P#4Uy*S zzC8KvWUwtuex{g~OPeBMQ!1Z!7jd@Vnwxh7ot^!%(y($Q=ikwzN1fkJo*m^u^3(L9 zb?H43M++Zm^9h-`oQ1U~^_lsy_+S~v5O$v!XT-<1)!Rh*gbK+Q|8Ri5G*Ry7dfz63 zP?U}8>RIaT=X0{+1KR$-7NBZw#L^A?G!{Rhk6SkzT2AOEeDZ5{q=+fSVvMEZMcO9N zj1RGU#Ox^1GcW@Mzp{nPb_Jnob?jiOgpf|60++nP3?&|9)|+CZf&Qw}AB zj>h4eeFH`Rhxj7&kd~!Q3ZXud40?N&$&z9sBaBhS(>=r_fQNiT75F$Kqcbz#xDKy&OWC6v8-x!nk<8?DSIfv z#1#+~6+SFTh8HFl3<#d^foaKan$&-h>_sBmovfHQy%A_)`7h^11)JjD(iWv7d?zS; z)4BdvScdEXpx3!-ifQ~})0nlM3x{`Ps_bm#hcw1V;BUk^F{vLRR{Sn8NFKPFwD#&Z z$UA&!sMdmO|Z@J3ZwGm*uq>S|6cL0_;wGb%;i^s%U z($MH7=_$c|2p}ik(qdE7sy?2Xv?)YkQ+-x4x_q)o;Jp~No*cZEuw4sn)b`)hxa4Lhg4WKO0XL4`IGWaS%ov_rnw3l2U zn+K$~NN($X7+RTwkL3Krse5se#2+=;VLdVM-x)6#Tv7Lxs~J7FxC>KIlpBacAUeIAIbW9^fr`5}M-4jDPq<{v9&sB{KsJziuw?A*KyG8dE)!vWTT$k)xc|P1D9kviX(7MJ&oGaYK@4L!*$DF61AK2b$k~ zPE4)w=lac7R=E$mUz+j^r*&tL|I29ut=jE02M~nd2O+6KfyJ|l&$Ht$;B_9x2r{)q zcVw?Hx6CU?RVq;Gf#mE>A$vAXKn%}eZp8n7ivxf(k zqpe8iQsdorh|E9C%?(A9f7Ow{9TJq@&?t02f7?_4`XG8Ci=OYu!7PMa#9~9?dxKMI zFtw`^eT4x}oC%i@_!@-m>wgJ{9-3Ze62%5wQr*O$g2*C_}8Zo22yg z{}7mbfF51U50&1D359!NnGPVBu9{*wLRuk1%G#$qr!oC$Us{#PPGFKd8>^SJwNxI9 zckCUyf`ds&*{BlXp+N0K5;Sjjw7bhe4*jGZKoy+5Lb=J*j821(bq+BWx!p_N1@lx>)oDOq_xmv zWsgDEDyw7xoc6k7<3zoocsPAg=Fq$oFDJizD#_B9aVlAy}kaa@-46th(oX8romQ*cxgvvbsa9*63inTpk?p|PGIQH_W@ zVPTR~SwlGAXLLDh9rrf+lkFN9f!U)^-&r%OP`YH0XO^|Vk?00o5n_FV>YO{WSA`Jf zLb%j4Uop8Dz|mbPh&_UwJcPJTGkHXR=KS6_PnSGgkzf{0AzHII&t=*`;ETE~4tG|> z84I0Iz81O1;wM7?A4^f1QR2dJ!PHA0y3o|SGAk5_R?u-E zxra3WeD0R+`O-PnJSPc&iZ{M+__E}m9g1siM^4@myQ=8tC`B2mR}x=qk}uSclu(*i zaAsWD*d&KC*`k<~lI;d_=N?#7vK*KS!}!6`u+xQ}|7PZ=bTCuZL)xw^FLP<}NG95P zP}eC5{U;uzneC0sgotV1-|_o@brl0psP@)xoXf00eYF6cjxzndvLy>1or|jC}-oZy6-v>B&Nne80{_g{#-$8CEivGm~)-x zOvSwhzwP8buZRDvkrfA1ZEcOt@F@>;LN*pQoQRwATvcnJ9Bnk$y7Qryr%&}0wF>Nr z_0)E%zWdxXeN=IO-@GG@la;>*=!V+#9W?)Qk;}@icj{J4{5dub*RMnmokn>E8>j@m z0Ng0v_@tLabYSw9g3&Wj4~VN1o3l;)m|r)SkG3<&@a-lolsk`UhqIme~fdx?oFOzQDx%ubU=urdTSQ(WuCH`IA)tDIPsEt{cu``s~ z7xLT{CgO2ChZDNWvMgE|523w8ak8VvW9dUNy>+JWCt;DrD^y7qnvrGO5*ub(*E3XEhom&Xz;PZF`EYX zrd`)Q8)n!1!a6QACMMf|+;NXPTnyw*+?)Q}*FshCaPd)F5z9^EacVzp)cxO|qg=nl zW-KOAl2$nXhOvZr$%1329@_xX$G^r7_#4k7BZp`6yX}cgzw~_Z@ufq-%9&25Wo4+F zoDRt+>E7OahLClP5~3~=^Tpd0I*1`3u9P8=9d>kIOyyAF79} zrl0ZaKJ=GZzP-j!&e7bDbz!zNxH`i;^V)^kj)B>(&?u1`zkI%s{QPzLvxQ-Q4+m!S zHqZNbrz*gscvOO!KmNh%a~se6dFyN()2de;F#Ky=@- zwZE=sf)bUCL*>B`@f{}biK2E3JzE$O8olk~_05XF(?vZmtM5>CPq&8rbTs957oLA5 za#Rlj&6ZO(bz+nmA{~PTWfBkhMxrX-h*(PSTjWi33C_|lEGxGsDE{0^TV*W|ysb(~ zN>?|$b4s02BXQYOp9Mw@=46agY$JEI7*zrR#fYoFxi6h1#F14=a@|4jP!Tc?g_ffl z8f@nzq3BzEcfJk0djr4lHI4l zj8=Sa-1ctxwGXWq{uslE74P%YxIUVRx8l7pu6Za(I{pc?Ujj=J!~qBgSS7f1hWhTa zdS&5bus_{ixL!HHM9o|Icf7l?3BSoOsWS=_Y`En`LcEO>SoE*CM4 ziYmG&>tdEQfIyPl_8mUWrx5^|1()pY5%j_-s^{CcllRO_HRwO4n{@c&y*1UXAKILz z%p^u6bg}D#gzmMUKYv3Im8Uob1lO)mCO%R*Z;1Y|@OnAm@vy<9&wN%Gg6s_l3LEADu>SCfh=W}& zaSLyZmz4Otr?!ilQc9FkxJXDeI0oc)<;1`bze=Tk>3m3-!`r}IMd3xpNd?3wC#*@w+a&^fJH4mk#njaQTtTlo5(q0*V4~ zLx9_zy*ii5kYgT(#-6)=?$8s0-4gjO?@A0#E_z!sIj`q&vns1EB`G~WY;*M45&yAu zs$cNwn{T5V*S- zKSHm|`k02w(^1Pj@vDZ<~lm^9QvcAd2`)64$Ayc0!Q!144PwYca+XtBdgh09i}nfGq{$;vERNBqQvC=RUCfK7eI=BUU8S%3Yy8@~B2g|r zxkoqst;eRXTUWOINAnstX^HM_p9A*oiEGa8`K(`ZN7J`S+73x8l5XVJd}E9~kaq6i z3+5@cD|be@*KXOYqEIrn&x`PY{NCfgEOLBWragD_?kl|HdCf&dr(kSFBS}uQDpPle z)p;s<*8z^0qLc*wf+!ul_F8O?)1E<@DTAi}lN{AiB3kPOsDLm)AQB(EZcjreoD5#* z*(jO>T@OkG%<~k6QO9~*yP1}Rc-miA=4{)X1G(GVVE5AAUYS-)SI26vk6pEV$CrSG z_G2b>mGu}N5drT@lUHba`dF#Q82_4Vm&}ry?3qr^owaWtD8BxsR?AJ}3ks$U zG)~<8{`js^*pgS$i0nT#3`#kiy{%jb+UfauoL%LI3^_)w1~$f_)Jrf;gD*N2bj?&s z(vnI}O1`{DId^xNyUE77jmm>ViF{QVJGLA0Igt*V&}^7`lvx`$EK-s}403xX9%wmf zoq(r7;$lvu&TzwZvP--EMX(otao*;j=E6tO_iGCWmPNmJ4lq%chYVKX`eaQG9<*Ag z!!@@k-~!EHUh7rMqPO=AOVaU%fJ5m7Igh+l=|At^=oTZdklmEQU<62D_}Y%;WXdU!ZoPl!;b26Ef||{%GwMgOo%G zn=^}m7C@;9vMsTFxO7M8h*rJd)2|n75hO!5oAG9V*2nR6@>nc!MWD!e3X&=j9}`L! z%poU6j}u0sFi@__G+~hJL0!Q(iCJr6DA(L^K^vLh@k+)Hw5Ti^Sl)P(!A-f`$;HNJ znHfs_g<@*k_pUtFxl#6w^BRBdO)_6U!OUH{KcckimMt}f_OBm;EB zcqzSMy2RM%!A?_7%C6+15C~>yXj-&EOBQW%<9(@>>pB}%`5ilEz{D{eVJl)5%%5+M z34=nYOd%Xf*}GuN6ACK%zJ0~q9C&XtEmfe;S?!HI>vq0>vZIIEZetK<8x1PsYo@&o zzPAcf6h#kvT{6L}WRpe=*Cy~!{e#x+01S98%V|N;H6~7(5|ga#BF?~mGxtmbg0jtO z2OG!J*Bzyk{errbo*S*=-oARamLsiJfl9qM`%X-wrv>L>7epJ@GWL|P90@^jb=UhSY+mXQ&fV(Ma(==z1f&}?K0D>CqMd*!? z41+guH#D&SM(+#HcpTevu07& zBCw=S&R1$r)%2Y?|Mynl=7;pu3`+(+GGJgIjhK$UvwK7BqZ&tX2aWyt`}glPn*G|% zly|Q{j?4MX?_##x?kSHO6Xli;cNAB!zyCa>7DOkm16 z_c+MaON;3jY;^-P<93 zS?(MnsmhLKsvF5iO)`%w@_Wz~p$aAS)|4hq!&yD~mzfvK>VD|%Hb*N($txq}rM%`Zf(D^qpGj0Ce687n6KYxQI+wv zbBqff$UJ?Y8H10N`fK+Ot+k(4{f1i0=?L#xq}JqA>=<~txPQ$=C+((}KpIw)JRg@A zpIS!GryH5Sq>}scn|@UDMETzm50hu5DuC!;*LJgQt~*F;^L85&ACD)Np~ecPfA)=N zez-jmg!-w0xfCtYua0kf2?>^ ziB6!P!*9jzKfh`-+9NUfsr8OD_o>^vC&pb#>gzJt?rW#q#CrX8LoW7BHECPfat=9q z`8i{CD2kH7z|A~E2BamMf8E&OA1z=M6ATg5*<^3(RwHX8EVc-+yx ziQP{2`Cvb-zHzZg|K+bd>dtp@PHs2u@z}4d19XDPiGzHggMFuV9lxmIe}9BO*Rt_p ztwc@7>chEqiSJm@%BSb!E4MM!gR&DsOyK+6%UDJv^q$_eRP_|vxQwdGwH@?+s1#I=lv zHX{eo&i(E!_gm(&s~@eOz{%VJBJKG6h^ePz=AYZ~fCu*v5e5JVOeZ*kItqNN^yvxr zya6`bx9cMVWD$}iS?@aP08o@DsY#m{C?(lJ7>B+T-|N{;x-#Hcopab^3)c}=oyXs= z9>WR(>5#VHdYkR{2+2YN7%W;+DT-BL7C`*Azfj|2;h~c4K4#HPEvt9Dcmv)+h|Azv z2yZUBtV{Xd>#XP@ceCRMxwLY77=m(|*m=-gB5j)J!Y&Jk$Gwv@;{#&h^FaYEgCRnk z@l+xWG^1qPe$mW%=(mQwL(!K;V$xB>$|B5Wn$zcApu_tUn!6s<74a1uVowM$fK}R) zFE`Eai6nZhaY0?>eDb$eI%tIe`acrvCa`yYwffS}b+mgRzmgR9v6GlbslX}UeSl`J zjpfT}Yax%RWJFW`X508{jTOf_?yCBf9U`V|4;?e6E5thO!1_taLDP^r3)v>Pui!f+ zB&fepfw&G*-a8N{vX~t9r!jJt3JO1lj3-x4I;bC>CX$Yj8^25q(Q)qs2k?H^Jr{ec z#2zdSMD71Q+@KdFL(ig35-Qh=`~SSHm1?Ny(V0G#8h9fvPDqhA_Z&u03?-{aO%E{^ z!b1!<|1ji-w+e*+2LUapZsCBImkT)lgk*)GMP3v)i5P% zY`Z11BJ%;Q4hs?H_e^A8QWhh8|8K#!d6p-9E>U6HD(B$Bi!%oi94ekEItJHeN5(=Q zXIX>acF0vyS-o!E{^Ubn=8aDS7aG~i5+l2Kr$ViTR(oAm2g?LAl@o>&7yu;vJ*ZoF z5lN$R`gT0||Fr-egPvFl{T1H?ng-yp#O|?;*v@=$A)yl!;0E=C5~ppYdqz%aesknP zsv5eZ8E zIC#ps4ogiXZ&STy~^K0odoAU$I0 zrE$%V9M4e(!2;;=2yrCEReejTYx$F#vMdb#Lq(|1VDcu?@Yjg?dD+`Es6fgUxK&9{ zkMASCnt#>-+eR-a-U*GiPs#Sq2ht>1&SbC%&mUZwuU^>TlnhP+^i&{Oyh>&0O{aZq zSz&D>9*)tt&Ch@Q@1$~V=zPeu%iROY(w$EJ+!^=t#=M~lVp&n7i70w}_btQ|qNLp& ztp8=+4o{g+E{b}Y$-Ot-1kOQmJO>Xjh7Vuvd=^xqX)gnw>|s4GDn`7#9hrnnJq;bN z&s0|RlG6dGtwccyzk&vc`;9aZMP^u!FB!6*cIZ6UG zm(8GE_XY|keEH~39u_C6SPm4XLj6Gs_L8dCbY>!}Y{J=pXD!beDekEME~?2hv}PMN zMsl5h+)e)&G%Agc!`tM4JQFGieK*(Ndrdc-&a%!wDjk?fE!d4WO@U;8Z>`y%LwXLb zzZWH!fSJwNAEj=nL6^`eI7>R{#m{LYEM-INPL@)Woha?uAQMtNvFrN%!&l zn5S&;X=8#JV1(8Tce}E!ac`;KYLeY%SVTVW;r4V(LQ;u;%mtMt)@t7mx1~(JTRKAG zeS>S`z*4#66D2SG_#U6FcNk>{sw@vfHIbfD`(CMJUejwTPQK`5&W)?O1T$5)J##>@ zy6VB6d+(B$EnBA0uiu*}hfllijRcE>>@)e>bZmcHMc~&Xouax=?NjEX{Rj|*+cUSW&c`D==ZImp=Rg)A%%*LOsbCQ{x&N6DZ{yL zA-SQ|956dd;oY)(L%(-B>o+FLYSP(LO(mm_r>NEz_xZ4`!?(Me?QGpPZ_fESRVQ<- zdGy?On^_)_ImestK6fwa;1=-BWKv}5zf#-Aj5u@Zl+X)G2uR2GNrl`+D5mAZQTdp^ zG%a*dRLbkW-z6dx1nFH)NjGq3)bpkLTI~{3$FI6|>D}8$R>m=V!or=ZGtaLr49Fc1 zJ+sNg)FQuNugb-qrIj`2-hG^=jPSbC7Wm>L5XtTj7ws*x?NZI-DdP* z*QBQbPxam>8MD$zL5i7Pvw2Q+Q}oMQcW!@;{%*perICIi(QN%`^1UDE*p|vu?QoTt@128Q)P0j{ z>IFtVw8}9kHPFpz*keWcdm*{~Ob2=7ROP3r-oHFqd!fOetkZ=OgKFC@SgD#0dbmw< zS&`m0o6;MaVwd~Eh=T<*kRw%V44C8ddOl<@Ey^)6W>H^}`2%kOc()_}R0vA;wj&y|Ye-CPXLh0Vu_Q zCnMlnlsv|fbNQL@|GWQvKS-`Kps6TO$id(g6u+G?obCFmmUnylJ=uQ9#!{wWOG4eg z;D3$4nYJVhrBLB0q*+B5FANZ%LXErIZDX8W zA`zEel>j9ZVX$-J**y9v&A#@Zw? zaIsK9p}0f6cyCJ3h{pLSwt!+#VN=ZgA6~H`!A!u}5OYL;yje7Dcq79DmKI%QN&#wq z2{GFnewFs6`1wJjh^72MOaqBK0~NOol=D&gLGJE<1mu{Iz=ACYdAjo>zQ1uSfc5@74$6Mxrti5DriX`;jUT)#x1L(PL7%)c+q%=Kgo+}w`oFIG`Td{cIiBNq?wh`S zKcDye8s~MM=XJqkMu(HgaG?(5Q^6v-0&Ep(AxKstzslT!NSOEv7rYWT<TJfl`#7-NaU zWz?VQ5neJhf{|ZhTx=q8c*{&F5MEqdo>Jy;7HW!qa9h!puyK>$7=DP*44I2gQ1OZk zE(!5k3mgSSqQQa5?9tl?C7So_ytt`d5Qm@Q2c|ORqw(U!hrryU?TgFCq4H|zmo^3d z32X!%k7NDfhmIZF9e(p3CVq&-y^bAg0xJcJKGnL(>DsY=SHl;8y!a{~FwSMdtJ0^B zBG-?img)$c_JgM@f+1jq$1cBMe0fX43!V(L{ z>>xSNPInrreY&wD+qZMOz-I+f{r+^ozNjQ0=!_e4bBr_ZhXf5!z)i!RW&+S2oGdyu z%B=%n_Y((qiXJ5E&l+7ZJ&1_snaL7-9YC5MqE(degap<@lb8UdUjaMKixEc;?Mf># zi9MrUPq<^8K9yb%OV~2qj-5IIz%9W%#S5~+cDB^jis-^ zIdCFFwJE_E;B3|*?2Ran0;6Y!7crcN-$O=wur$bF!(9G4G7X1Yf&@Vs7HyfY-8l+1 zPCUcpNxW4q*G60vC@H|YqvPT%w<%osZ1ldMQ?U*89}HE-pxqgY2=9~B9@;+en@8S1 zXA7A(Uh36w<4Fy#0uV258A;*7wuufdu(6a>%t8Xdh8d{Icu@wqOazWo49cPu4_|Km zRF)Zw8KT9zYl3l4^h~9~qHhnf8 z+P~irS%@u49}+?6UWOeWi~@zH%9Ou6>VhawWw4J-crs2TWH?Lwe^^7*GS||S$4c~* z*iX-lk4-Cu|JJd+NmLwYj>V|3b2$Nt?Cm(jzQ#(#LqhK4m|zbqGl(KUuY0ggk1!)7 zviD|?h#-C<7C%87K2G1AkB3KdKKMQnT`pHd4&?5Aj~^QfYDob@ZffP9$!|o=$7twn6?;niX7Rpycx)6YjdFPOkQ$>r>p1`0K zrBqJ4F!p6G2zt0z^hszjyRO;a{%ydP*Jm~k#}w~{9E9Wi5vyk0h(l~Psvwr z&0Xdxf%r_^yjcrE;XY{W69@Xcx{Hki5T@=XKkx!DBG(Ux}musWs#h) zDRpDMgmsxBS_rmUiKgeR{&?K}F>~9+xLg&HHQ#Fpv5UuvQ1vUrRB>pQaZ|}bT=K!t zjv?9w2c~Br$=5!*T$Nh@Bgb9Z* zp@f)0wlMDRLh2P5krm^o#=E}0vX3_>Cs(xL0_&0b=v$~Fxa$HC!w;!vK$E!%rW%LF zu`JZ@KP6z(rifoyO=WyXG1OJkTFR75U_PfKX(eZ=M3K&wKlNnyFb+O$hy^T| z#u51%A9*Mx0A9MX%AkmoJ?@bkY0D9{*VpHBW)Z1lwDtBscH)j=bC{(*$TL4yhVV4F z=fXqC6mb2>Cjuof!hE8>A`V}0p%?%)in$x;-la{0)k;?&YUzih0pmhC}woFW|&#?2O9vvydf<`1>Ky>e*6Ct0;5!a!zN@Be5GZ}z{kc+-S6YBhP2lDe&nDf zYI_rKGgwSpiN7JdZ~v2XjDs`HS>!GR#-VI&{U@@E%&cc=apEJW$8a5;Me%pL=#Ob0pA#0F~e!ilHXWdmhvan$a)j845-J(>Ys=Wht!YQ;rztv}T3!A$WdB_|%rE{s(;OP8G9T~% z!3X9#8<-?cQwLMrJMuz@VS;d?@kC8kr~|v`xZ}#(yj{XLB8kMp;xF6j*#=)0a-!u- z;z`GuET6BQwrWJzVHkYjc-(>OiJUBnZtO$t^D>8@&w}c9Ly~Sx*)ObdnaY4BgKl6Z zYQcI1(2euAJWa!B78>8FuXC21O@SwO1R@#lc9v+`C}sf*mv&SIo(40r2i4e#nrtcq22{8?QEx``9IJt1k=LNxB*Cud$@S@QHuiV7VT)- zQzW$)Cfp(XW7?EBs&HGw;LK*=W5e}@RcX$nqfFmLS&t>U(2eK=%_(i9yMcx%4*se` zhbp<&|0$FxDfcGG<@W&fQJ*)CmJ|uSn|mt$e^4btBBw(rkIG(}L_r~&RggumIVtw5 z^7WB%&8|J3(vwHdrj-VOMxuG5y@hm0>=|2?qRQF1M>>%;5!Apm19WmlGArZHGq~&B zp+)AaSNkbjy6bN1g5%y9|N7__?LPc|X|`cGzHt(@Q7{7=O6P`Wzd4oX^R{|Vsg1eh z7Svjd$2PrNO?`u2q$zLu-zVLv=H-#KQr2vwVEyzF62UKoFJdU~N|A6bg@lMNM>E0lSJ6e?`{3+N0#Ba!Qh7I;R( zb>|HIce?(4*X#M+3Iu4h+JfVQvo#>C>Nh+hL&13b=MygFf0j9GX7+@OeO5oZ4KGZd z(FhZa1I6ke(}BrIh3W1$4N+RYg)b+LVM>%1+z#sfof)MTI3`JK!4{*Z^XFk$NSjP; zc;q3yIo=MZ4V}oHVz)uwXZHA{gW{r&Eg5N0vCtS7LD%?G|H`@Y!@ptQ%j<`irH1^_ z-#V*z$7Z{gj=uI9=zbcPuA(0o{@r>wUZrQp?-v#vYcO!S-=z+|t#>F*?B7@KL)kcg zAYuQbPW36dd4|pQfB5>xtn$cO^-i9KIT_WTZ&x}OmfW)Jc;LchgBde~j|r^eAFF}s zAV)~%h$h!#DXOrm#q(?mUlor#u>!H%wsPgl1&bEZ&Z_<+`*OwB>_rc*C;L&#z7@hM zZFub1F#!Y3%*?1C7(q5|+7x_CM9p{zb4xnojyz!C(+6L-Z9ns z@3i8bvS!V6a3}&-2^g>JFGKDldYc0WnzJecfyq&^Bvn%C9rp3b<-~2W3-yYjMfNtSf~XvyM1wS zV>r6pn+Df$PPu&fGVtZ;WeF~vxBAmtDqn}Mg?J<#D;r93S35h+ocRxRp;S>~eKvg% z6AB=m=55+MeE4vmxJaU=t6jadn;T=hcDO=JJcG;zl#AhS9~`YPBZnznG9mNsT}7N} zb9OXLxis|aSi1j`zqXq;VQ$|Q(<|<-=m4KQ@xUt|r_R3F*KxZe)QAf>{I$P*Ce+mb zLx2$T0Wk`|`xK51kW0#T4Z0rIMZMnoNj;}iC8nUvL&?KL7$I2YoO#Rs2f_n|RDQgZ z4`;SD`VY)trOBe&)8rC|x4(P;{%*odY=XP{LgS9*B_%R_hpYVR)vFXO=N$};m|7xe zp2Bjhii3RBD$+Go`m{-t8ZsZ|uz6!&C?i1JvWeW9jtki4!t^T<)*2>!qi(gqx>rZ5kA$tL zsFBQ+u{MDl;Xlk`Z>yi5zHi0EbYm%zISimN@okv%j-%#M>X6nF*mm>4+m)}|5wwiA9iV&#`&Z!rAHxXLvZ7Iy)7z| zWyeTfP?kJm1_KWsJlG3jf+AD=hH0(v(z95<{s|41s!s@Zn=;b; zuj`sR6~N~rbtS7kPMsR;qX8M;&!7uo0?a`mnClqc5u9jZxpCrh1y&NB-;za7^sdtT zQM!}BR*&DK-eH0dR?`I$*ROo&(5ZUbB^lIT=rTL4icn9Z)a z83;FEyQnp4bNT;s0azw(!4fzzL*WeoK4J4v62wvu$d|NcvTO+}qU>xxfQ_>D9E&Mg z>lwqeV<$6nTdV#%FW7Xt*){jy&66kosif-{vgal}z4f3$gK*3ywk0Mf3mE|5;gJ-a{N}6)Wu`rN@L*D# z8X<~9Crvv9iebZt!(F6z=$Ff6s^;HwfiB`HSR3rjZqew_Z=lzR0=e(QwPc zu@@~f%YOd+DGPxxXPrT21DHp>nNs{5+OSM_I(6yIP~S6KvM-9#en`iohk}!w%J4Q6-!!p^CG$;{u02YY#?Ja z#gaW|p%phCs_TIxNBolY2)u+j`VOo9OGN|z;)P~xzjpEWX5~k32nz9I#pu|YR~7=2 zVE08dlq2`(kvk@O@*71*uOd)|x4#4*R-8~lmso=|s8WuC*wGObAWJfqAc!8Vn?*K} zZ54nI$&VjPGbauQ1UF!RvBr7+Ty|ktmEP;M=S$pqL;rgWmc_)@IdspS#?%tNjoSca zonxlP*QS|ei~BLH>DRJnbnE>&?)>@w{7W^sC>%wD_G&fFNxZbk!iXc&9$k`;6PrZ+ z)&~4-rV&&(Z)J_jh}|GPydkP#Zp?dvg|{!;DLuhx#paI38+v8fs8M3d%KauYI|1gz zY%V|1w)XO9=P}yaR0HzXf!&dTvt%!_ZN1sI_djfV@!$GodkozGSqT~vpXp?9lh+VG zRp@?MSaUAj_H_!pUS$Sx=<;tCYktGHShiw0Ii=g z>`YKkw6w7?!VW-|*O9y_7``MnY$MNumq z{(RF=78?Zw^Z<@oN!yN2UtR|`f9iwQqXwHrO+{oAy#m&=$6mfP_4e`c*WIoIb(~gd zH)j~X3L_4?xLENE&~n1Md<4dYX)ZcPe_5mnTUzY@iQwAkfe8|fz>hOpU~q_ip_q7| zLkmb9(bK88Ep_#wiapa>%Zdbg5QJqSycXA|yu9I{>y%T5G<3>zd4h^!p|?Nd=Z~i@ z70pq{@Wlm|WHiP`sh#rK9O<5pZxqp^X8=QBtp0{h()VsSh2c>SGA=Ufi%m2`65&_XoBV(EU=qEXPGk)3Hgym4o9OWX{|?O6kA}=PT~nYoer5AS}zq zV=`t+N{Vb71UZp?E+A+roGonx(O>dzyWKR;;$Mhkh-}Z|$-uX3Pw?Y#{K87G-}o@Lm-7n>-?2r~~;P;P}2dk?q;{}Dmg2JIaTR(mJq&(DI zK}3gm!*VJkfDjD^34RCUcAqFIF8}nlH2E}AoWTd=TPR{*!P4ykO}M&FMF;weih{?F z!HuzrGIbo&pH^ZO&!h?L$Ft)ppkA@>#h$GSG^v^x@QedaJe3`!oQrSed42+&;pdKV zbF5Rf8frIHxChps^4tH#;dH|}5+|n@sQNA!?nZd{S*&4YvQk_;dFD)y%xzR;@5E8P zurP=~7>=Kwh)Af0kLdpXmwOySfDj|L|GY_1S%D3Ki4!;whIdp$gSB9iIo z4a;PX#h!5E2AAJ$TW4eQ;N{Cs(3{WZ<#p=FYql6G9=EoF0}Zr+r*$1~_VW(K)VOo}NDzf}~R1l$!@IHF*2vxIJc{ac@a z-#~*tjUYsIzZ@V-NaM7?pvQ8l;DiI|@8k46(u<5aPlB}@{|KoIuJkx}N%cewWyMVq z)Lhz>uiOOsd(>gQS(yub%9bX|+ZB|eREP=>wDg`@A ztI`QjK8V0p{83_`{A8K9P%;B9@KH+ay)g|d0g<%i8b#i`>CIwQUWDb1y+4hH-Q@nr zDaNY_5aJ)fE~)G{JK!^Vgb{i*fl4kFcUS1~JTQ*fb-V-s4EFf@D{}`{p_T!lc$(nq zDzh094n6MQPs{3sz_7r~ECTQq>q$cvl=Gm%RbB9~!B1aCk}QM;Q793wsI;^hF0P1m zFo338J|L#0uU`FI+aACNZR}l{uG!j(N6Lmfzx1;vv9DpV+J0JtWZ>s7EYG1vA!Hn+ z6hkAkPh?++EKMZ*8BCvEMD&tnQfwk1iKUZ*@ksD1 zI95I;t1czidc2?9es2r9$KxkX$QcL7lPv`NJWG_u+%8!l2@VR0;vBRGDlO4tqa&-Z z4!ZyM`;Q-`bQqw8?wF?1+zKt4=lG4oiK~dquM)p8AKTaZUcY(sJqPC0IOk)O6GXN? zag(OJ8P2fPO_~7gEOQp}u4aLj@1Oz%@RTzrn~!+mqAY~zfi+kW(t6EXZ7A>7<;Qr@ z-^$9ag;8}>+F4JZCeqxpO1>fEj;g>qV`p`TgB49w%(rXML4XlUWkgLU+m+Z_B2qId zFEQ{3A9@52QSsgFi+saTqejt?aMOCdirb<=%MZ4GgbMM^yLXdXVz0u#L4li1F(`rTNWG1)(YkDz4 ze);kkRY4#n0Y%^>Zn$tIMs#JgG7YP{8&OdugkpKJ__N(|C=<(?knx2s%fQFp!%_|< z$n$y_1*J2HWqt5_FmQ9G8`ReVE^&OOz~%rzjJ95o-@53sS_=XEIZm>+Qeao_c0VXL zEg(yT4_(K|knGB~&by$YGZ@X#-?5EHIwB@bT_{@ToHN8V#EW|~<5nUV__Xo%TuQbQ zwgHCXHy$u}a53Rw@$Ieia@~m^3XE6vIrPBD5-TX4iomp(f71Q>Jl{sDHTjceKJ*JV z&h5H)9~@T=9CDCug7+-tfcza!;g`&*y=XMjNCQqJFJ7g{@n!bMJ_j0ODwnrZ!i3;7 z6@@bwouw4#zw3v7RA4Cr_0%(fFRqakB?_Ziqzx7(8wmEz`m*$4axyBAnfSViJ+q9q z6sVBpkdj2a7|}vXg%z=_&o8UBYu3Os^LgP*yG$tbftl~id&wzUO>6Aq)1GZ2$Xnbw z^T2B#*jUb9AiwA8DvEa4x{0+J?K7VYmeihJPfh_V7&}Z&OG)Vs63b?5F(xC?G8zi* z;En#(y6?J3&Fn(NCg6 zAJ{Mn2y7eIE+#Iny|VHF=)D2m_(uRVbYRCv+ZOe@bow`wuNItge=8BXOOb=t!6ms~~66lKgRyH40s!Vd)xmc3dW;+rllX)if- z{mQOSj_Rz3P?xhO8d^r7{`3VPv@(j|O9(>Sn^}vL63gxwa@YueAdhwoPy@Gq98k7# z1DmUvm(|r)y0*vi8#t2K?fnUZDm!SQ(#_yIjMO9jE^$5eSIhO~59u2F(I&f*LAy53 zugr;PN(q8AofgMu1wFY*i{}N+S5&-s`Lb&*dY|FWHR}%d9OurvlYIgA?>AzELv#GK zs!FUlS(G4?veeSSy0YtrlJf1FH@ycX?y(MZd_bNAh1*LWV_TBZwWYVgl+0N=A-xqu z4}hTxe;qxC>mVm}72B02biVN;e{JvY+;*B)S;{MpiT>Nzu0WM2mx9YrF%W&IjGGK? z#E&Sx+`=MhN5iHx|CZ={=y|4XKTPVzsSg;I^WsU%*hdmMriQCgt0^YQW$dNX&( zWVBiL>wG8=KM6n@YLhscq}j522#~7H-Zs6`>Asyfr;3)1*FnVq=hicun!@OGA;vSp zWx9FuX1K{?7EQ745Ab%tz=8J)3Igb;aE0yn;}p+EMi1m&f;s!B=9qro%x)yr9@lL}jP7Yxr%#=#+<95EF7>C6+pCZUEh*4r1k}MfXp<} z+WJCdnG+o;%y4AGFk}SKMxca9655NKR(**eq6ubjnq|q;x#t}TkQt9 z@0Wr-QMnS0Ow-pgT7w})H!l$b~crA{@w$0Zeu~LXw!k3lplX2RI%HMOk2fJwwk;mz%&t&#;TZ+s9$w2 zmzH;^yi$rE11kZ^D|8>dQO{IYf9lls?3tr7wsiaOqj!zfnO~sr5Ak~OXz&a$DEwp9 zW~D_;!HJ0%%4;1i_?Z&DXhqJHIfD@=FTsc9MnKh8gTVMm!f3jFh<~%d3EXS z7LmVFm@!L1LW`cW{MlFMK&SW0FeQQLXvw=W29j&f3O?XFIwi%F#5ZybS!*EmBDE8n z)OOG(b<$Y6Eh9j$$F(x*HC{glL(&HDvEmytHM~1!+K8guI)pWbch)dBy{}AVcotw$ zGUbPc$N*~iiCJeYG++U6ROygpl383SsgKtNYyn{bv#81x^XNCR*JyM2H(+B;t%GG6 zmi*#{FEe3-Wi1nIt4oWEg^BlT+`cxw1ny@|v8kHslqt4cL8_B}tYyo(!`@ffNg^I# z{t8VcRI`Q)`(RSi1adYpAcS3XoT?t&_%a(FKUz+|=S!HZLF!?*dOCMc_JEiCq%D#t zAw@BeQQ`)Nx!ipC(crZMqf{}6PiNDVlt(2_1`g6m7i@}V0ykiw@AO>L35_FR8W3nS z8o!CAF9b=D88#oInX`&eM@62RmKIFkM~V8{*yoz6&Fgqs=5vq9oW_qCAbqmtg*OEh zc<$=eiVZW|s95cQkFhA6L772ODJz^9G-zK-X34_WcT`OVn^R;YgX& zAS87HWO_kiq@D2Yk*l(NU`^QgT{YPbl*5>HNKsi_Id1IO{a9b;Ozj_#Y>lph$2?hllm)u@bD?Y?6Hl&Z3^d)QthvL%Uk&?+_--IbbS0Neh~pTWmS5xBaJT&mSXQIP0l<~K62>L z(Q;P>(CYmT6LXsy>>2*=sqcrIu4w!8QF{8F&?Tm$QwieB0dM3W!S^7T3XlzM_#h`| zVb%2(%@*|=CTmHh?JzhpZ}#loWq$WZHgEbmH#he}iyAs1A;ys>-j$T}a~%@%dCKfh zKN=XSsj1w^ zpeaXCSjn=DLeJe7S~!!XofJLPFz95xi6%~6q&s6rh9&$glU?_pKkq>0q<*&-_|gB!{2ZR~V_(3oaejl4e1 zrddyW>#(sus;W>r?0`Zv#m@J2bFgtf+vMi^7CaU^>^_}DTm)+eO+M?htR04STpwND zN9Q=z%m!x#lI&Y{bv=Znx_KYi3i^w|u?xO$tsOT%o;C2YT?EKugW0@+y&azxG5}-w zsPn9SBe;;t9hmR00w6Z|l37sD15F1wFo#m#qV3v`JY3cRrrf@I`m`JE?ODdOH*VZz zxf3-K!=>I~o%S3*jRF3;yGFBSooA9`GNN*+7*cK8R06;%=T+`BY)&>XnVOmkke}#D zVR35wio(lRudbv^H|d>CXl5V_b;|%=#-S08r#*JIJk4Z-=08^BV8eMlv#VFHa%j30 zL<6W~(FXH%7P26YnvibF7ewzm1W1M0eIFfk20Xfo?&8uVS&dQUw01;S0-Kdzquv+q zarsJ?cgnqGN-cO-NYByH(V1gA|49@~no`F{ZKjHr)-l=|LabVps1;HkKbH9ib>(ap z=mI0~Eg0Y>jSM2+J z($pQP^=h=6Virump$?yFdbg$5E6SXP>!-8=0oLDf4OI?m4AQ`VXO<*J?`L24*GXiO z3Ik5%<>w#bn}6OphDEQ!VoOVFNsiAs0MfIYQy;R*e&a?@o$gpE_Z>cbj@?I)w}|b3 zHVZ>|^+UK?;C|tM^95LiKR96j8Y`=0c-@bsvf0^#%8&Ta5Rky&;t0FqX*mYj`T2$b zNKaQjV@n~BXp!6`ILiqqLps-`Qgfw~>tGz}J69Jpx_A1CUx$v@={SmrenR7*46Amn z%cxiM%w9Bao^XFS48Q>FI#Vxeq^LaC1qW09dUUhc6UUEl#EzfSJ*fOV@Dvxuv+szIgDQ)gYWBVKbmO8E!!2{dt!S?_QMqukLNof;C{8$%8JN{B z_r2ZLt+%{pi6{-gG~~{qX7T%fEv0Pah!i^mA^oClKF%y;+}3YDC%sE37`dK>%W!5D zheD7CP9St>wX^D+-FN4Vo!PZiDi}qld}Z|k3i1>%`a|#0XoR-%et|Q=ujNIncJ@rC zqGI?CRUyQ6Bt$z3VyE(`Y}wHXNnCobm{lRg|DOv$Ey$yZyaml>_ujqJ$KL=72N_O{ zWk|c*G|abyKhc57?TVEv_fl(256!wTX5jGQKG?ekn}&Hf*xt8~Vsyy^V102A!w`@w zq4LWF9=Z+hUeI3N@8H!@x3*LszJ2@lWxJjrxaZMq1gES#HYoW4 zx1t!GvFHgMzi_rGsh{70o6I#a=ok8|%#kUn)!UZ=FXNsb{Y#(8C~?o&;WUskn-H-L znokK!40);y)gBZ86@vG{$q_aQV=dOr7k6F1W{oWBvd-&uJ3bzQl<%iWhZJ+un?MM% zf|3*6cOy>}fDhXKgw63XDzMJ$&jb#BK?poE?y&@+v17;LqOy`3X=g+v9lAQJZ`t@! zBSxfDw4-@p;4?s79gftuWA=`a!YSBBBIbyh;#1>N^-i~f|0Z?GnxmBdlx~|j8AfdeP}xdHjE|3WG!sdy7fwoL1;vR-y9rzfyom2r$B3?Ez5kM%!6Zmbcnb~ zJOb!k9PjAGp;&lhF0YwEnF69+fp_^E23H;4#vlzW?g!n7IoCQD z=&W$XxZd;Dq@wNo5ly%H`FfF0HJ05u#ZzWsWZ&_(Vunn3tMXt#yBpxBSeC}lIxvpF zewg3UtVY9j32nmQ#W7E<3+MjZAsQuh^}_ncdm_ z(qvJA0nsZm+bZS_7U;Az;H{kpND4QQDfP*J3n6)9ZP^w|6S9|jmZU^Kq3BxU%O{S zX!(04s?U9$Y;_=M+n*nLCpSFL%2FA#^G)>a7@Yh`b}4-?9Up%-ViF{;DOWq6 zIrx$o@XPs3p5pqtz^5bH+JF2}vpDmg->crRrxp4m6FZvF30ZC3ivp`lGiIDw(N=FG z1;WnnvPg-t-zAX14sV2TF5DUKgxWhp1_qDG%}D?E*6u5JPBu!Z>{nH}8`=O0-~}-o z`(Ui&_S{@4z&v5To@$E|xyyVmT<$!+^U=Qd&MWP>TXV;F`SLTLRsKdOEovnU_gzi? zk>#AaaW&a}KhHtw9H|;mKIoo{`{cuJr>E?^>Gp=Kx*^AJPvYf&H@96hpGtKlT@*v9 zl}HHrJjYes7zOg1} zj}2*yIvv~8J!rH~zQ?kcUu#?J5f4w^qelyO+R1+NLZ)>oP@l49m68D+(jt1|VtdVz zBd1);uD_NGoIVKRh!`u7-L%kC zL!R39`uNk;c7x52Kh8V8EOg0G+duq!L6+5~CdS`}vGR*}1gvxHb88A&SXmzIm9UO5YSH@QjKV9kdM{P234t*n%vlP! zS-QXa@2#L*R z%t$kOs9uo?J4U37ER`qXpPF>V?lBe<+9Vgw()aAE3wu78 z{Bcu{v^_kqLyMeCE^!`Mddyj{z%ftD8UTh)Lz~ix1M9tFxp3|N`3oed-=jhK!~VOM zkI^xJ`!Cc=NOxx4kJl)r&w5zcazX+CPOmXz;^HouV127;ce=5k_0-7h5kJG)wRhd! z@{Vfmm?o}KJJ{l06p^88LEM50@C>e|tcoK)2r|R`gN8Bd!bb>^hmpa~o;NR$dU8EzeHx{(3{U?Ni;vLEepv>RNt%clgJY+3}ICSM(nKQ9t!Q z{k^+mf|3>qKtE{yvSk{T`&bM?niOQT%5z(PcnJ+cne*x|G40WQSkbL|v!5I0o@3N^ z9ZZpW(RKbccrltvU#GIf1s(lz#-}V;OSRl{+J}NzO;fY#a8PE>xP>Gn zd&{5h$;*Grn3@fwnqTGv9So??ZMsUQ+p<+L1xrG_!Et zA5ua6_5^-GeutGv(7n9uk3ar z#{%GM4~Z5QQxe4;JV_>Kp0J06NekHzZKqd&e)+YF;o()sEU#RzEQ;!p>0kLW=k^%K zFK>GMI9?+Mjc5Jid9C95=GO#;9$WYGou-=!DUcMmnt(Q9#J9s8#sysYp7*5b*B|9I z=L*d-wyX^X+xm%a4v+9lh9(#v*j2xuNw5R&8a92YqW;~8^qm#a@6L3-#$wjGg6#|Y zhEAD(m#ZUwINN`J>jM>*t{hSrLk3yOZ3(v^;ObO_8?|vYg{qBgnLgT>HS5>t?UZce zRrC(QbU`!{EX(L!J_mMuMMVXiRAA2VyiR8he1Dzkg3Zv=SFfgR-(IV3Z?-r{sj%X+ zw&T{V;b?6Q##K85+LiEKqx+`h^)FBUl-t>t>St@KfKA_D|C8JKM6GX)!Pq>YRH!wB zS~e;uC@A^$6LgCzV#fBV0M~SO0$f|(yT)(AT&7)jI9nP3Mx@Db zH7Y+EE}@NTzr?iz6&@tbo8@*A~Z4 zw6J)$zH%Y5oy4Fa6>&%Y#2oqav%cY>_<2yPIZVI6o<&53Jy~B}-9TZ}(N9x33U!nd zP)w_)xr@zsNV48DnCU-!__JxVjpw!<*umXs__Ol%9o9tszPb2riz#L^Xnob_mRP7H zmSvD;Kupivym^c<(QLJ~lReG0fsy=S1m2pE7G8U!wvs1!=#u)FiCz5CAH2xQy2poY zxOhCi2kFFP=Zv$D!nT+0xX{|(-rlFdd)LMlQI}>P!Z0=MW==xZ(-za+xOxoi4bhb- zh?&KCk@ilYq%i(6yU`Bzir$A%2-mmIh!JyJJf%nWA}u0h5b84}FuWDngiDrQpICm) z_{{7#rKPgm;oLcW(=7**=0Rf;%I(eibuhJQ%C3)`xPE(bD(_}ye(m(Nvi{|S_m#Vw&9*7H_5JEnjhFei zA1}{5mrr&f@#A$TTb={~qplRp?Uutu!*+#(?ko1py_x}#X zmpuBuw`F9&X1B{7J~s?G+$i!%cFo$Qp(ZyVS}eRs4W34SCUzcFrPB11GETw6SOEog z(xhGMl3YuC=P$Uub<_Dz&0T`k%37hVVE8kpr(NTwmAqqERo-)3s3_mGOl`N?;{NM? zcj_Pc4jPN+n#tn0h_B7pTubl$nMve?4?UvJ+r+dN?tQ4@UvU#3Xf@PGZ}vlmblv)W zuWwahHD0qz?ugXMihHlma((f7C9D8MHL$FNF$~xH%hO41g#D<%4F55BwiRBKkdolh6ZtZvD{5X(p2cJfu+hM3l!hRPZ1CVQ`LG123MYC9Qw-_awuaGrxZO zcC>mPmk<`nvsvRk=$ye>WpNcnq*SDI-hdOO5XWU{C8|=yGth2W&{4&(u3M&4ROKzI zsF*9q4ll^wvNf=1%h~;GcRtRBQ$U0LpgqrkHy)O8z?MDVs=b| z7{)%nG%pNq8CYrl(B?!;T%@5~*y}c-zLUY*O{2JX-u^{L^g6W1$KDgL0sMwsVgVM2 zQw05Tef9YIqy)BRj~}^Za=Y1|Ymr5vSc98*=)(8i&3>0ZbxA+?=lR0rCZq&t#JzD^ zA}9(hCNdkEaNshj5!e#=>Oye*?eb2}R(~eyz3Ei1KYK>YagVT<9SXlpHAwILv((yx z7cV-(E`>Jn5SDAjn%iaz7QA70s;tFZlw$)?GzsA#;HaQ^vPOXri(;BqPmI?|S(b0Z z7F`kaQMmb9xk>l^yLoo73V3T%qr_teArayn4+%Atj1iz&9OkfsH9}Qd-JbJYl5?Ac zO=(x~;ZTEw@*5L3eA%qKW5022!p6Xa|7EESlL9sA4P z0g5VmZF?p^TteZNLH4F%1bQ0nS{L7{b+cVVJ~XF|5QaB7oqt7u!0DyfcmC45`T3`S zG!PhBeW}+1EjZROlvJh0Qij`0H%wmN=hV&g ziZjiQ=U-H*Tq<}c`CEQLXjwv-f}c^}@A={eAT*kZ9T1!Tk4?C%!_T#aXP%*}hF=Z0 zI}u30?pW#j_tX1ZKtblbOuwzjsgMN&AbMzfDCgQ?+{7(o2uocqgM~melK&%pHHQ%a z&}p({i3bCJB87R$BVWFtp`js;BH~vcKrBdE!OFb9{snEmvP(0brDMG;ir#0!(Cqc&!G2# z7a~6$nm(Dc=(|U|c0mH3ap^9sGOmRY>@-F>vaC0}DA#=Gv|AdA6T(y)obJ)4D=}AyDuIFN zjIoeT$@q}dZft7$0ODE6_Rh^~`MF#+c?%C8MzG*m`O6Ruja72-*j|e8h@|H+VyMo< zSC`XuGV4C%T5F6|_r6CzR|1wpADQA7r{p*H;#`3%-}ooMput(jHlqGhXtQDge! z>MS&VA{(9u5Lw*Em|m|l=9dryk}3xN?SS)OIa6ckF&RhfFu;8P^?nM+fJwWobP{en zgE88&C*f9Uup~v-$3zqcEL1AlAPk^IM1-Bmm#}95#8KiZp!BGw;1t#*I0w2B!Mu^- z;w&I~Ra-kdEYr;)J;FoK(lRy4=iiCx*_-I7ZVc>s_ClwF2aY(9d8f_AtsO^PWR8gb zSgB*jeqowSa#pWf2cu{=`N?G2vJBuQ)FStCbAQk3zpJvzie9qQ81NQuz$sxgaCTs; zAd1?i2cH$%A-+uJ9GkfO$)S~yV%ZTD_iZGghw_-smy2IRlDbp6q62Qt&M()r`*d;m zz!Ow=s%CgE!hQAqv>o*Z&LW1LnkGMdm=;nMbi-KXm1+YGez2Spv{A^Z03w$I!h%F} zyFJUgXH;Zlic_PM`-?a>kSFvj`3Cp1v*$P23ms`X^*3k;b*=E4OOMB#tNGpK-a{=4 zizTBOI8vCW{!x)gAGb>+p-8Tn~3I zL6<$8NYs7PdG=+mZ^vD}2RD>jxe1WFgF|HewH4!%A3TsHP00TWUF&a0)|RjLay#BD zI%3w-P7HD$7XzN~w1P_IyEb2bkB&7)wlu&~%QJyd1XsO*2Q584)Y*kawR zRp005^rv@&b|4x90CK#00zseCB6620cI(y+j3Hsr~1G3GntPEEM zNdse9y`oe#IIM<$M~?G!<4bb^*||WH`wzS0teDflVq&PHfHK9QVVt=ZzTG6Gjz*%!JtP{p?bRy-(rM@w z+FVNl9fVXtFQ6Yev~0X)>fh*Fk7X+LPPcB^&TkbmcgOu9wlq*&U!?S5lCaQrV1wrO zrnTr3OIl{U$K**?x3O5un_MESV_MH^@l`GUtGC3Y<-!F|qg$gVyH9f8IHIxNwpH`& zc9t@A@F5~0?qP*9U0IiT)7^&%MTr1TjKUq=Hzj2nxsWoz5+y3U^CKt*p7EH=S}BA_ zOj(@p(~Qz|U|>Rnd;7WC5Pa*cs8fS)rT?w>KRY06+6m*z20N5`yZ7xo`C*4$Grwkx zGjCcHZkx2(@IcGWhVh#}TeiQJX0x$f^tyE~wT76lYjsB1@S6SzRiy@v+I`%#a>V!7 zg}okWdwEXosFj}n!^yjJ*a_!Ko<((FwCk!&`VwdWP~yhcetRaaT_;We4X%DRX3^*z zMgNZHG9$k=gV5*YZ?VLaQ%AsO*+2$oM0ND&1>RQ}-rm3s zL`!*_O|%GfSTUMT_E>n`e8=N+4UxnHqrdyU)vEVIhM1|7otO)TECq#P@SFj9huzN&wT2d`MmUF1Er z{t(1IgYxiAuv5t!-uK^Z$+z9N*pCscC>_wWi!l}7rSI^juK(^|8iL(W>E3QOr}=3} z;1{afYdcTw^0f87Nje|z_iV*hf=}NG`mQ;5yZ1S=sItG8|M?edhIMI0k6*NLbne_) zKup1x)K96aB+_IjdkkIny21bF0*rc{(ql_L{+n^1H_&g4-umrdAz(6_A(Mnd8xxJO zPJo;YltGcI1&5FrQSs}|(0_){$_#X#6qA@9cwyXkIbagv1B@kA{%@H1J1VZ3!*3r~oo7zV2JI434V3j?Estf6o>sEJI#1J57-M!=Y%6sMZ)=g*yr(>xKa*dB_+zp3&4IS17G?iii(%$S%C z+fuPu`~xfEtBnk_{yo&)FRkC7M+YOs`|PVkW4mr{`b#8ljpp38!TUwUTy z+`BB}5MmpJjgN2MhJ*lu9rG9tsk0-7@d7yF#-~cTUFE+jC`ctP;p$bRl^3^!uA4S( znmJ{?D8^vjDxp1;?UsNseV2(Je{pop*4vu)S2mps%UO30`?9!#*N11lH|e9f7S3F1 zh=ETRNE5}2%ck#dbsvUKhbha^3u4YJ(O)c6kaZC`Ee*d7V@e0N_s6<=GE=Z~%&z-BN^ld8WNX=!xN|Ff#&MbfJZ ztrS3jK|HeE$Eruhs+dgy?=Bw;zz0KtkcG9Iu7>LaUq_7_LA%=`vF8Q zUp8!_hRTGE*k^I925+y%t7@^@p2e+q4{rb0ktBhAU=S^#fz}@uMMmXyCg-fnyE#)w zbzdvxdpvi`wdqLA7U<;eUWex;CI$MY8Z@4V_&c03T$tYnI#;aD-fWU`3bT|)PCt`+ z)#>3bjiVu?x*!fp)XbY#9ADCx4j43u#Ws?kY+_3~4e;B*|AN+L@IzzJ3x$qDIq*b1 zI$zZKf+xae_9+~ssya!{^}s+KW8;3`Cg0eyYbPDU1*}Vl)A|N(t1`ewq+giwp2r>3 zZ`^lwkLqo?b}F{v@~&m#6_oK1CKRR37~x*x`V#KmbWw2uLn>`Z=2J-5+v(6W-HZPo8J>SH%QN6pP)#VoP}NdcUA76qbc`e?B9f>-x3SVn)X z=5x`*X5pMU&-ON3GB>Ton7R4&b`MJ(bdbWzE)^*U;LI-FxG8B5x>jMh0Vb}tw(d=n zp|h=U>!&MPFD7-J*>c!+Cns$HC^U64b6hy@C$)Ckz`s{_dd8hS==QXfSZURG?fkDf z2aaVweW|N;@=2C%r{D1xYPQ4`pgz}Rf{Ix+@eOjY2QbSLyV6FAY!FMeECN2`K(Dy5 zK!kMo&3Uf7atm}C0$t&I9ng~X>#wvrC3PxxSre-EKRG^Z7k#P^t=8suVc6c4Z}!0> zhHGZk?G}YKAN5VEvhU{r=DeBDpHHSd$KTca_tNRpThK8bn~?kR9dp=se5}dd+X$;1 z$Mz(OvPDad-LkqWKd~U>^Maq*bS>RQjm4}&>=IDVhz?MB^sZy3E<1k&N2$+Wy!bJJ zLP4?z1`(rSdoYv8x?(m;8epkKX1P@PXa$T;ziQMrelHx5(E;g3o90L44F2_iMJ&k6 z?yl2{Kx#xyoqJA&=)4i?py2nJ)<<7_F1oBdZtIAI-!HC(Jkj{PrtP#2g+|+2T4y$% z{ONPAtno^E3W|+3j+}IE5R4+R6(oG{<(FfDcc|yM6R6L!&<^|m6V3;SNCH>J;7s`d zU3QJH%n0O3T71Zf=+EO(fw90rAgT-*G@Z!gk`RRQQWk#O*@ZvuB3gMcB)x<7^~CKJ zv=imwrb$UjVqhRkb-)~s9WG@E91<`m!%|b2{~{V*xHr((H<=O)?JN!wGkY2$5qr2Y z;%)UW*hVEc^K2lbnXY^^lVb=z^r$XG)B_w6zZaLzY-(~@KUj?47TZ4=DIP2xJ5E8! z%x{LY<~jm!vdx#Rw`E$NXA|c^g8J(+9$me+`UBnizxJy=%X1?3k*)k#4yU#7ZV#)G zQ57?{YC&s@<)vFiJ)}4gC*NG{r5c@^?AigNm$8>r;dJ8oJu>0KC(kcuMo~1RjniOA z42NV_yRmWGKR+s`&YMRUGZokcEy7C1fZnf+&p(r8&mdW|!2Vgk^M)V)pu|pQ1%`&L zs2{`{18|cq-iD0z7*6*=k1i3A5N-Qx<|wJC^!4{X_7s-87yytB=-%a7GTmX&SmHIc z=UcSGz6;|q1dPrrc;Gg&ODhcXj-z|R%SgES>{8MOfubJs-$QClUmD3^VRK}YsLbP^D=rJA8@EQc1We( zz0HpLMGrYheA$KyoyrLx1hufeROvh&-z$EL{4O4B^$hxsrg@?Qb;1yNV}yQAAO@U& zU0hFUw+7}zf*=#s)z8lHa_)*?=$dQoT3Qq~t|T2BM&QL8_yB$lf|}dhDpM1|J;0y1 zK7xC*f{){Y73}F^eMZ>lVN-XriP=F3Hk6NrtvUl36Hj$MQb?howpbJT<_`uJ>O)34 z$B$r9qZ>s+4-8&t8tITH&6_7OuZQO^4P@zj``v@kGgT64A9}*v;ijcHfQS z{{TM} z6r4i%PI}uLwIh);eyIkNN3Qy&)q6ht>ixip84GHL)F)4bSSd%>@wOn$OpwKsC%mI? z@(R)P0;?Zqa_QsgnaHb1VS+!vYBD5ssz6LiBWL#e82i%llWW4M!G=?enFOpvzW^*c ztyET)x}V)h3}F@UKBg9QL?W8QZAqqu$Udi2KT4wF_mcXtQ>T|J0|Lk0Y^fczsD2z* zY&=X}5e=ei6~8T-R-B2^JsR|PT{~{6YJ+i@ z%a`Xx+q3f)JS7pKW-BcRQ>x zq|8OHVX%D!iA+3TK+sSgJ!S`73al9dIjeU)7>VMqaUY-5B&sRQGtPpIK-AkvPlc;L zE2aQ5zvBEV3SicJh`wDcoi1Iv8D`3$8y@Tq`mX89+j*T7O4({L(#m!9klo?vI=m^p zDXx<}M!nw8)>9ln!vFci(5SeUR!cU)Vk79a{(!gz(I4>zdR@8{ZeMC7zL4c9Gj_xn z&0Xp~{5RoPFF2D$%)%fgs&d}VmUiymUB1mU@{>zKFQ3%DVFl72GdV5Zz6^&`t#z+9 z8jUmkyJfA})sd6W8ZKVkow}_sn?bk6r8R#hvKu0qv1lw}R!EaxcO&5_C3tKfyhz7} zNC$6GMsftCCqy}FRpzj_dmLH$>Bm^-Av(`CB`1V3k)gU2HwtCPPA@i!(J$X1=<>j= zjAtR&#(#*I)X1G1t1p$bzmn>}$6(_q|1VxX_liA!VlfYfSWI({JqD8cHnKJFYuK_{ zcj=l-UfsHPHF(SC6w@`gb8A*X^rhN9S{I2R>^XQ_)aE)av^Y*(yQop^2q}X3TV*4xcgQ7IzlF=5yV`uVS%3>STZw2TR97utQk zf?7u2Q{wpGC;Gw0H@!|^$wJ~qPpH50He;S$j7$Q6AHBT2Rc5Hrit2O9toFKa(Sc#6 z(zM*mM+{6|w7RoG+l7Kc~o*(hGCvz_nLn_oDGu zzHT?VyA+6`H%EB1=w|cj)26*41wcd@PXEOC4|lB7a|3^r=RM>MCURTVJ3eAoC8CD8 z^GZ;?;-c>(mGt04c%>ee3?*br2Eo+y={6-ZA!d8?ZNv%QxL=?u`3b-dk+ZUi&tWDC#XFHPCxr(Aosb`ne-t8>Q+ zE6>AaM&(zw7>6;Re)s+EmA>6LT1pkzgX_ZwCA+jc)%i+~dX}^Mq~7OL{`}!uzlF~L z>66opOzQ@^e)btogGE|mm%6sVW)!0cP_`^QGae`r)L4|HU8?dyz_t~Eya<6S>s8j8D00d$bdXS`1yO(l0fe=akXMIuKgOPqX1V@; zrllV+^pQl!Ih;}kJRblDs7pL^{8&X%#2HwQbRjDf;pl%mB9h*kQ|K^398ScTJ1?0k zkcLRy_6RP$J-h)zUT>?OnQt>ndPI!i!)Dh(%rLf9}W!*27F3Di%og1F-SqBffX$ZT`lxRv@sk~1wa{E-yO-Uw*fb4jxOS)Gk}r;@*7s$Xuqc!CH+|CHq7y)B zn4vS5+qwYQuH)*s4$8`@=(wbqq*S1dU|Y@8Gi}i*q=?J+rAyK-QXGOmHI7UP2f2-a zH6+)vA&iY3RFDyXszTxs-7*jwQ?Z}5F1*kuh4pn?fR(gaj>qmQ92z(lj<}-<)K6X) zpBNbZv@KBerRvOH>io|Xz_-~RM)ZH-fb^bNkm0DFXnlt z`|XdWC(vaVPL7dLa3%|aA2Sf8GQol@0MLeQTK%%ulXWEF2#t|3s%OrbowL7xuAURT zIyJt-?_+8Bb)P?yVwyod8Zf{heaBBVm>5Jym}uHKponq@Pk-bd;wI2x%3EBNlh;J8 zh{$0@jn}rpgmAz3R%bVPJ#C(@7TsBA0_3gFdu)bw5fL=ag;;G0Kq#s->`0aKmqe}U z2#e3I;NUL*aJbBS`@bm%bggS}e%{vDjZPGu>xBLp@B|QCnam{5=oTa5Ya~$h4*+Tmi|57cM0XvN_AfJJU*5(3VK`Dv9gVsYAp_m_lF>;-ASjj~TFy z{~uHD0gm;*|Bp*Tg(M?VMmABDLPpt=qU_N^RwW_HmIg{xBr}vK$zCCwWS5yFSy@Ro z`9EHt^Zi}d|6JF(&N-jM?Y`glYdoKi^~8;h+g{zE{a&%Cb1e2Yi2lgE=O{1dnNnW` zGp>$9?!~Sfx&Eg<71RE}{TTrZ^C^T=fm7_}*Hb$7WEeg2t+vsh5v40H9O?!CmZQ(g z1dKwEMAL34Gn& z4PWdJ+(n42)BL#eDs^7?jicAuBsZG;JT+Fpt|Sz z(Os%~hzz>+J>R|D3Hb)hj>t~KqelCN`N?(W*Lzigc9WCQB0U0sl*G6g+wO78RG~D* zQv3Nvu7?dnYLZrWJJ3RMN(ct*cTgXgoP_TZ4cs8Y)Q9lTf28Goqx{$91vPvoMvef` zCh`M3oJ?7mAa0eePFx|#DE3X=A_Is9tV)>w>ntO2a-g#VtWItL9G%aPkDZ6X8oe97 zH-2ChPtxYeveDB7yo2`3G3p~_zZ_Gq0%ShK?FmS-3S$wFIE6()vYlXb`ozCXkJ}j& z9C5VzqD3cKjiX2X=!5`5wZAp52PsiQ4j*At6f}z@!^k8Jw52~bM}2!5b|hF4g6I)} zR#)}sf$4`3K%nWtbV;YxV2%Tu?yI#8t+17iw`0?v1>TRg6#CQ;s7>9_cLCD4GH-!0 zjbt6<_g+Txr-kR7M2}f@sA6e=nGV7FU4@CCse4v@AsiHxmpA_?4prl8l}dJfQ_~%^ zK2!vYbf!42r4@L%?$&N$dHEQKwc&jN7+>A*=!n_}obGLl^egHdLYXl$fBua0n!J}?3K)oM@+>?m)|s(F=`s8jMA|ORzpD-H{f*f5P zPDq=qxzqYDkg*ER7<0w=mUAH>B>e`^UJ3fO_9n!1N2rWjXGKB|0+$Gx@It;nq?F?P zLWmpR{Sz&lc*;V!D2mmHdYeEre-~sMh;HMTaQ_gztes+-dFZmAw#L zV-^^M#GzK*X*_-oIKj~@n#-DE5Gz7?e&W;P&IrRW3`9JU7(K{!VM2ae`p+G_5c#HY zjN>!4B8&&PN;RH86dEMs0hVB$dbfatJmA%kzg*S4*)lw{6(!Gs4zp{Gxkvuzzl2gt zjGpl3BvA=vYm)Z=rLbk@E#UR=w1K1{kBYd6QSw4aoVzi8(=Ps&)+>;mWvVv&tXmJ< z7~B;KYTS@$Ml;ZZliV{{&>vUUUL3JleW->A7>l~%5gq`P{A5)}t6-ooK4u$i`VH2Pj=&Zz3H-DM3!<@f zwWpNSxlrZe&h-+^>$zWNc^)M%GE>p)2GAgfGkE>%$vDafi;9 z1iN7WYwj24?he3EjJsG0k@vv(RN#hLL=TP2*S0Oy9`*{@b#g@g*%+en^N=~V^+zTO zZNdeC`h*GGqo(G)apAtuii)ZE`i~VPvXGr!*z^VqGc#$fM)z(jF2c68R9ol*U+w`^ zCIUp2InzC-u^VILXj}T*kM)nUWgk5E#x4X+#@EqV=wcr=L+~XaGC(5bF8K>ar(yMA zoB|2;00w5YIRkZ9{Q+UgS(xJhY7M~@Wj}CO7BvUSFUF0G+kA!#We?FAI5ieB@hXL5 zN8p;mSf}^X)jI%Fxh#m4abf4EqPa)elKCtpD~sBs5h5%js1OOw0f-7WGpQB9S|}A6 zXk#>JP!;daHMGBXD#}Q77WM~$7MQeD9_~RY4`}G&)2Bi`31A=ZqweYQyNo3?p@9L= zBV|~6xQmLMKe4j0k>I5|dM6oBu>ec+jFkz8S1(bjqOVU*WiUY41q9i6C&DS3@IS7uQiS|JEr1obK`kx#Fd$yg zN)cx_Iu`Yk^wd<+VW8h7#B8C9D2}L-q!|+`Q!=5V>inp);=SNWpw0x29Tm$oNrSLN zKi5)BT%4R_&`FWff`sn?jTh{2L4XQwW+3ZXwrv6stLw86u;~ZzB{K8MaziFY3Qo-v z*!_kKGAzhG08BtoITUbLfuz-W3b!xd{z32dWRF%J=60c7!G%Qvs_>Gw7Y@bf&{6w4 z)Q!w9gMNJ^O+W;=QeMQzE6p{b`V95#SpGY@TZRFvR2c{XVGB?oo`r@=vYBqr%??~~ zQBk5Dw>rki;RjOU-K8FB_Q#BjU;QvQhs?78ApJ2Z1?6@IVlL{qIyf*X3Y{9>l?KMi ztjJ?4W9Q@?mlJ+23e`?#g+UqEO(+N8<=}I9lGX7T5GthVXbv?>Vqpqpa1qki$8nmB ze~9b_h|IZ7XOJ-SnCP&%7a#M`V-7$C0{#X)CSesw zIUX6Q3FZiLOd|0F#25$wz*Yh-Hg+S0jzzxIAsm$um*ZMP$*B$Z>{qEDo~RNKXi{#o zKNBAce4lt<8t_Z()akgj#W*1A(9!B4IkZcM+GT^vhh?O4{Crz9S>EsX8>)8_vE^K= z_;cEg)Z-)+JShpOCj42v)^cxxn5vKa!SFu9Se)V?FmUK&fA}Pqa!kb36P$YqiW$# z_DAA?T} z#tn?x7_xdvX%BV9$4SQWo8JF{d;*8J$%u9 z&Q06|$NL^~9Piw99c#eN3B-m74WQvgXz1H|^)497~E-;*_$=)I84Opk3I?FJ5rQN%C_v~aW@A}GK>QiTFcX7dN z)9IXm&hCE4T8dz9mA{J~BeP2kF}qQ%Vr9=biNVoXE2fUe3E<2+tH4}|!X)zsW>_Gd z&EcAcXy8#4&}0NE*un2;h{?@{@}S;s%|$%hf(z*F@THlb^UH&WO~hES^YFU(zZqVe ze~(k?2FMVm)Oa}$9M!j%2b#NZDnl1jmuVB%BzSMVoYXE=)`yU)4@}qg1$|0^Hre&T zY272lVFd})$2VW>E-6!cr!Y>L8X9)=&#vY`l5DYK>Gj-v2Z|y z^k+Ru0O#ZkZX2x}b^dsR@fv@s8G^ie8dA75sgl|n$7cUP15BEE0H+OzHSQaF231Zl z;Tmz!Iih-Z0;XeYNlTN=y@xgJ6~6Upx5I?5V`2Fseyj%)5>r#r(lZ(wcRd7)dc`~Z zaTEj4BU6rmsogF5G z#^imu)2Z&~Xh6kQwRppo!q76&l?wM5L;K*4c7$D~J}X`_4ww&B|o{ohA6)xVjiJ5Ei zor%7p7};R$AZ;3T@gAL7hQ(j^Vh4RgQ&YrSKYn~CtuH0<+Vk1SXpH8BTM(V;+EV%& z2?{1B6k?WQ!S2=B1A`!0@F$rANJ4Ml>g4rHegW7JNW=hG;b6QSRIN@j=5XV69=fvU zo;MHlZX0#Rg2n7_>=m*Vr$FbG6Yv~_1`a$?DeF%*GdlmNS7Fm4`cuSFV%UYA!DeOC zRke}u%XvFqf9ci{&x;kjQ_&(|!xQ$dQPq9V+}l^Jn05$a7l*R!XO_knSG2otRPcx< z%eD@ij+G4i8x5bof)e7okg`fu2~;@>3MCC)J<#lfK3@C}3vO*CqpgVS777k?j(D~) zXC&OO%;lB7z^zR%IWjo{)u4{c?$SPq0!5XFwRSllrJtaIY_-O-$KGXRV&Zk-*!t=L zg7aiUzi+z#;KAjJ*G{fhU0nc>fw-P=a$@6N6l|>W=1xsAnW|}DHpq@e9aP~UeF~mm z(HoT}Z0--#5=Vb)v&tUB=fDuni25p=VTAL;$w3nBSB4IjhL)ayo&%fuBJN`H=3wk2 zRy<*46)Shx{cy%*GRa{&by4qw-phirvV{%TS*P$8`TFY4XP}1^6P>C0mJU<{cQTavb%|H{zerOJ7!Sbmd#KQ&BfIRb&QKO4}`2&HL9j+V1 zttLFQ>A`>>kblYfooS|{cQfoJf6Y6hJ$P}HGX2TV`Kk)cZg4!H=R`3&@V zAv(!p`x$A>i!4wS0JvGf0yqd+5=jRnt*wOxxPMYTMMmC33w5w`kzS(suil-*8`Em8 z3tw1Kc)O@@H78`s&aZE(#zA@pB8KTlaVxioT<~LV{nZ3BVXDUD%|&q){ok*bF7G6= zzm1!Iu={lk!xKg=d*z;s%)qi?kH4W-G9tqU@!Vy|)W;UP9B46pPoCE znJT`S(cyV!Vc+=u*LAm%L^H%zy08HS2ZZ~7vf@^QaTb_MD;nB^!!~t?x2gL?D zg+$cIccxs}J*;W`L&vLikRCE3r#R7YtVq-UL2EST^zsVhLf0B=C$1 zOVxb}BPKV5Y1tV#z!qJf9fBoZqI*w)x4p@ ztkGqihj%_xp9WZPxYE4;4SS!Lj3lNxJkGGBwyAk|ZSigIa6<_5^zv$q^@`cOS-by& zDMIDjQB&0ex$4PYM}jJ#TO&h3cIKYVkDfOFAckQ*a+o0Ipn8jO!}A_%b~k@(6`#QON~Zo2MDJj)^EPk< zDEm!i9iebUlSw`^VO_JQGpbTVRYRzyXpj4ln^#bG)vq1~ghA!;7@IfZ-0v;N;R6vuGAvPG``X0qRn^n5`@tHa0 zA6lRMzn9!3W1piM(68*ptZJxR?m;?sDP0l4cwkp%uNefHZ!9>!pMxMdfQC7s(Uu_L zNl+z7I{t7iB1m}<>mIv`=U?;(icV8mIzU?pP>(I>ngP%~pVZXWNWMc%`#Fw(_hI1n zncA3kFC8j4i5|(jfkEhj)WR#g`xvbT(RvdfE_@u3myB?ypatSvtNQXZl;{wB1k2C&<5P81yc*OGg2$kwzJ|g+DQSkTHeU zL-gp92M?;!Fhg-<82%X`?Ich#B1G%GTxyxHL_v9o4GQ8vn+v~{jm|8N^cEciHU$|R z0n89P3_NsQ@tBC8O;tA(nz#luz=NX|hiwOHB9d{3ks<}-(q|(!wYZ{fErC6fXs|U$ zkNw9l?L;>6uXhvX|9dd&R*V>ddpcUsCa=#v^o+9c1qJ1T7nF@SuON%7JI=RM5T*T+ z>vJv3YVL-b+uYU%BDIm{_wnB(M4Ks#@fsiU`CP8f~{u@G6BN)6wvIflH$6%C?QnEJCqxVb?|7g-@wIt4&3}4R#t$!eo0b39N2vR~?xiT`u3`J{ zg+Dwg`y#%UJ~zFA@xvhZyAp4+{AZ;dUk>N~cTG@G=5td0y0D&{2LBxil(atK;W{Yp zD#SV|{#_}bn$NSZo4ara2lCYXT64w!?+3BoHtaO;ar68>OPO{_k^X5gqm3}f|9u*B z8r5}qG@r9p?2jDTv$mPr$6Fr+4vGz{Eq}S5g_I3UgO&g*je5L4>`*vr9^^PtpbS~z zmhUxvB%Y^ccA*fjZH}v!HJH=i%|stHELOI=-^xVo$lKnT}HWs1HB}4@E+3$BOus@@3%Bhy-VjDC0b)Kb3@h0c+j-rI@ zLoX^5D-W>!VPKO}j|e<=Lubss@JE#MTyR$DBZHfh`&^wHsy;1NJ!q2nt1(+1o0uB^ zYR&nyMM_4F*S9JUK~>wuxDWF(C7X?J8oGO=)wHoU9thd_;KV$nKU^FHuUchMm6=*FDZOq_>(`yN^M1@~$?j;+kQsB)>aLcrLD!ks zyNBHcrzd4&-YTo7Y`>GQt`0?-PMU$}&Jsf=pbzC+M<@2Q6DS(^Qn>OQPzLWjjW}8Aw3hA{OT++=IwwMQUEOoR6n_)UjU5b zZUIbND4HAoFFgTD3P@3jRBQ3d`35c&B&dlPKUZ~0BB|P0wxIv^4g!dhEIRz?+v0fiN6=RZq?^pqKN!*n4z6GB4mZ~3macxV5GifmoxEtkEb-@IM zcbsXI=?uNk>+t+q@9_DOwt(@JYh8wEgWsxfi1=sC-|jrYmNnOCHFG_{tmI^NJB?g< zKy7-E@$kj;-JU&)er2z$I`?i{m7KNNpB8hE<5^5YwCrGo4Gv2GmRDs{GODePbvXv>} z`CFZ`bB0%?MR!_vrrOnK`6*}L3b3Q4|MJ7Zz{cTZh8%Ntbf50q4*R{^_uK6M?kHqf z?K^cP(lk-mKh>`$F+pfpfqB)dd368NhCLT%+)jIx99t8Jkw@4jBU1uT`AgL+&CQ+h z&7Da<>_t*u;w5jbcWJ z_VWzXul0jcm6Q{YJ~FSmzM*<}#@`~jjPuPAiae_(7wY>%TwQY)>%_J$yt`$yk^T$g`IhglDW4C%IyAn?OzEt4dfVCk zwkJwkeZ0an872ND_& zya?wiNJM8ly3Im?WG)zva-l%Kw$X9D7ky*1AFpd--YMEIOe}=87loCPSh#f4W&ke) zmS`Qsom^lBG8L3}%cx1?24D}4j_;y_uAy}P@!y4gDfUZO4#$1!kQ$UA_EKuoI`|X7 zkO2gsxC0d!3J4gjGY=0Bc%yOGqLth<=tJKbiM7&hy?_c7Wh@$ApbLaUbS(m`oe2Jr zL(H&Tm1{Jj7AQu#a@IA6U_=# zVoEn|Eb@CRg5JZ+ld2CV1t}rlnC3CWHQi~jhIWOWbRDo>gn{lyu)3@9fFhU5NpheO zc@TaTUNNA+_E{3p8Uam)5De{NbG6%Fx5Hwn&X8Pa_W3>9d7}MYTlQWH0Y3%hf=iw6 z<95cf$Hboj&(CwAP2v$jpz~dXfLYC$_CuN zeO@|ha2KS=Pqkm_3~Q4LX$#J~`1pA7^x!|st3=QPB}W!IC?Hn?R?8vm(AEX-lnV!2 z!NknW*HHvfIw7zw-Z`ww&NXVZi&wz`o@5MC> z*!U1Q@Dy!376B27ya%BM)I57Rt!F|F!N6HvP>hsYN|MowK&YWTy+m2fdx-b%(?;jr zJl=9kMNPxFhZxz}yM{{T+RlL?CBYA>sw89v>Un6iu{2h^)^7N#<5B`knv&%*_a-YH z>A!&I7SkGCVw!gXTYV9HczOP)pnTjn`aRd6YLa^*gPS$6%P83^U8F5cRja`-2raW+|{$5H4dT-oavf zL^zBddM2ei(rW6APG+3}3UN~j;|Bnb+4s`%f#|ro{dE z_b`Qgh?|ORe=(8`ADr^3vNAH(*9sl4^aJnLCtl{r?%Ec?b5&XxZZv~jI z`2BY0QP64n-A_%;K*uWKFReZ2j~!hqEmG z?Cd7>2MdeKB5Z?$CE}6FsA3#hJq+D8X}7#l@bQ_8xa-8<_>Su(ee%1uE7oc26j zTzzBIqJ1H;m4BUF2-}-}liPIrB>DP$zI6ZKyk5}%*EYbf*do{A>cISU&xd1YHl%j{ znNXuUD0w!19lgl02vL?dQ*>=-{iQJwMjA6nF1U7_dL5b`VLYdy@N9Z(O+yuJ{ z7EcdZ5=3bQCO$$a0Av_c10?YbV2c_QC>WeZei;&z|9<(cnI?*BBzURn=A)?0AAK1G z#^BPVRbCHK41$OqgwaDlR+a-zHdtOX>_`gL`*k#pB{z$u7A~6^hzP-uEsmb>{)i8N zguVRJN>v9y#-jBi-ZG+Xwta(siKupoT4En&8bVkEG!B*;e*c~r?_jh1=>xU zMOT?>hl>!%;NX*lm+x)yoXH!5*UpD=#uE(ythd*U*D<2CMLhuqgb!T^EJJ8oaCL)@ z^n=;nA6$p)O(V-!WIP$l2Yvl78)ObhFqfCfV}SDIq6z6>wcRAIXwYf(?7DAnRU94n zf{VS0 zPBL$RfCFi7$H3>}t|pitAumhQ($fWzTSrXa@O?qZhUhrVS|m0|!jbRiXcEptVwo`vG9Uphid0x5 z>WMp+0LCE42$x>EGAIFM`pf$sPDQ=ANUy@Jb8RdwuN68uJbtVKc+u~o+96Y+VT|0d z{2XEJ=$;*6(0mjZM^H}SM`X8PLc#Q}*Xg;B>OnaZh!Iw3613J}tqflU$|Ry@gs2!9 zD{jfdC1_|t?h_X^)({}!^mp)CinjQ(h!fUZeUYMJITVZES2EjSrYQ zip<1qp&$?^TqKagcB^4u$3OEx|53|HLQM`GIX3VK2zaVwhL7m$qv$FC6iS{x#hHKJ zbKm6hQ)SoIyrZvOE*+KU+35TC78UItMp+Hrli>;ZLJ1?O1F5x3M-&TR7kP=<_n5Da z{7Nl){k>CFiRy_!@q9SdLq?6MiQH>bGNpK}m*%;8gvJk$~v$VHhIPGi(hI>O{he$|xk(>hBHm1xP{T4cQrh ze*7fK1cL0hl6cN_1~hb;F;WT2_3t?HtiFHny*-eLx*wrBGZ?Jg?^F1TDfOE|xf%GKnLH;r|{f`ju zmEC~}zY$Y316Pu3SE8lyOJqRYoGhr>u8x$4-eHQCL7!ZQ>FFl#f{mN>LiyGuZ_^f_ zciZfA^zvI34O)jDaYSLr=cfuW?h=3a#7aQkJmHP@^ThRDlf|mxJW_3y=`Uq!KRu?U z+V9B5EvEhbnZ4NGU;NWu!{UKK2WWhDxhP!qd!qTsK}Chut(YU9bLXD0XtA*7-A}hD zj(o4xWsjPm8LGauQTnCX#i~(e$D*=9y{odSl9Jr~RY$agD`~h29XpHEES5c*^V07n z@QT*e@$NYIF~3(ay(B`iSciv)cblLIyY%Oy=M(Po9l5Z3yh~KiEaANLb@9i2>HBtV zx?}Qmf0>&5vyEBCs=ECnx*H>Qh?o1nt)*Od_wL@S%N{h5>;g(J(vI^p`uuhbP2dgf z>^vJT?sE0-)8M)!`E48`l9!&vY^~U%!IYh+e6#4N<2xmjXBthhOrcwOkCiU0%P(P? z$v@Z3y;d(cGOn$Z+>Q%V>F+wnqhNZmnc$(msu|twlqH7XW$@xY5Ybpx?v?(>B=e@jk^&> zd5!)}9WLX-oA@fDp6+kaI>{LQn)dtEGck8POSygws0@9&+A3(c!)RAXpqR4bL!~3$ zp>MRo;0T0!wfNQpI?!87kIdbo$xOHRQPhQ_x~D}%(;hz=bDmr)$}~=M{2d=M(c`(W zDl#$g`@{Mp+WA8wZlXo|P#-i{OC>xDK?sAgvT{U4z}r<1_=u*9`c@ynn}FL9Nt6hk zBO&k5*W7?_gp3U=GPA@Dhdz`Z%0|sb%E0IMc2u8$@PgCu* z7U*QChhX0FA!H-*W>Zh#wjl~ov|X8E*Aq@33#$aESPi8c84-hgEJiYF0G11+(dqPk z*)qlEXALD5{EZ1h=-6>5lG&WNeQEv9ao<9R_7G0d(um z98BtlA`TFDMhjR+oA2JtKxWsQ< z#3aCe2ye-m=OfK_vEi92>f_|Q6R*!LnE}tao8VTZRMqri3mw~j!6V_KF9foBqRiGC?gUy`G7XFDwuKlnv?{g|8sp|fWdv)70%66 zlp0~5`JI;RKG#jSlJb+NF>!1fLBhhn+sXSL`tGZ^$|Vsk1uGxQNXWt}?;ch_j<3X$ zr*SwDmJ_ZP(qL*+U1n9F(PSz;s)xf^iRMaob91G^k}W_A$c2(~0k)w}(9V1RFhS});Y1;Lj@1Whgfs~QIMh^Az2ce zPYUq!PnU0&Xl`$}_Q*wKGw`ArE_Or2S$uV>-$T+lpa@fdeT#_h@#QB^-8j4pKD>pV zzJF>6%kWk5f#1Z^7+Uw9A zmYtKsPPYkWS6JWbP1TW%^hORpUGI01^-P2h>LsQ z7rr*cLroBJm93X2M$TskHp4uQVKZfqIgqL&BO^l|*!6EWY8;1F&Dy$HQD)1LdYg##9LE@{hR@%{2j?~;|yD1`i| zoMaZr&wusjrbX>-uRUAOqQ}f5VDQ1NWpmQO4>vZZFuXH5kdol@gNDk|c>TzJV}Zu? zeUxQ9yqRzGj9I=53T-s|Y!shLseIe0uJ)mg{i&`@)u9pH<2&T#dde53SzFtV2XkE< z{#hrj+$tceb^fQuj?hrbe+S+ z%d%1CL#|?w^;_84QC3!J1@F4VpOaTPqHCaajB;G=dzsgf&!yt&pVQL=Mr?eCte>!) zDZCf7dc&pUv>Yw(&!K1R7M3rroRN~=k#F|hF}jp9Cc?I`nR$bryB6I}yBAj!H&0*X z-MRDGj55uFgx8Q(Jq=}|x^dmd$0H*rI}1)(uX%B;Qx%=B-D==I%fv3%uBLz7=cQVJ z62;QIR9k5r`w<_{S9v9{Zv?ZU)Ig^yi-wY(SPpweW(pJ zfVI}f=45Ani~7}!Ek(-l-in)5Rau>!v_E&4E_&GdC@ED9_^w<3DR}#H8AYvQ>pk*X zlrJP6UGrkSO`rAfiSRzB#{nl4Cw86uvYpRlnbt3`q5aEMR$3c|pq9k_?dl%yHEtFO zj0r?XBo^gIQLekcbMHC7Cn)HC*##VrdjB}~;nhP_OY?bq-5y5%b>8MrZT#xh=U<WaGJ0BK?V~i(u6jR9dwhq z@iC-A{9)Y1a|^Vajg19FME5*?vY2z#Yr}QHE*geAjtU7DR-<}`Y{{|Qw4cAkoEZ17 z)~>6Uk~)8jZlr|3X{-2tW4W{qS zJnrZ*r=f>tIYO1CnV(eIQ{`YLfAIu&{A{WClZ!c36rfz)ur{%p`jjcycZpF0e2-=OTHnyW5 zqg(3Vc)!hiKE$2eGuJlHDW80Q@=g1gSHZQoIw}q|k+T{9F3vmxfP{{WlMErM&3gR! zv2#Huq^9udk?tM7vX@X-5$2pC5afD@NGuje>|+pw+qVGBkY5xMP&&I$0=)%e%YeZw+= zq6Ekn`P+~i!DmG#kSHwwK!(Nd{pM{v`UC@<7^;LA%9P)YA&s;|GB5QwSOCfs?OUSJoEOSEJbfi%emTQ z^H=LVWoF!qNe~dcrc_71*(9rNnoZx`Kn zN{A)b4D9=sg^-|<#i1QaQ5V?!e->UFg0xXuK`K(4cu2Ng8+UpcLQ~x9EI9gw^5AU2 z^E2N%($A2rSqWpoRrXb;H~yEP%blBZg*ch#K&yGw*kDJ@WrC8TO_qmsCT3b5s@eeg zSM?ll7(xrwAN+p%p_bjBr;ua?ktrOzBz%oTX+Ty+z8HWB1lV;$E(5aEC4Xi(J2u$m zaEwa_RUL(nf2+J(0jFN3*nn#hy{$%3@4rN`jL9Y4aI4)14X#hTTF)jPQU5S~grTyv zCgbe{{dT7ZVSH%%*DmPYP7p%GvX;8~ofdV)+}opLvEmI27p#gL&5_~+xRcA=@UL`B z4-zl7w~F^_u9!t%Q@T!zuNFn^??n%&ixG2B7_90U*J7m#<`u@e-nmC&^{DLd2k13* zLF7Qimz;5ZpQ$z@71-0@+CC6EYeo-ttiJyp)we?JgUpp!hq(7| zpTaotNKCI=?@;`C$CnKkY$0FbTNI^`@?OBCt4mCk377L8P_sN3j<+u_rRl)l7)+<_GgV?peX?ZvG6qfD6KA@KGLyI? zd6aL%Hv*F^@diaq2o@K=x0teO3eObV21D(X*Ss8@%0w$e?4RfZ1>vOw1jv4I2x0jZ znoiGEUIHV8>J+Xq$fk_myX`~!qbQP=8dt%rT4pKfB036XTWwR*Fr>g2k*W%QlH*Uv zQo*>>Fyb@S;qimn`UL`(pj(~B&&L=`v7DItLB#EH%*^Xe75snM? z?=D;~xt^NtJ3LUZS`hyf$qmCX4wc~r*vv@lS~OW0gkoR zc3(u$M5@8(gzvk0bpcY~=eTc>IOsZRzj}CTW-*x4oo#9lQjdvSALuH8HZuDG>hNJO zr@)yB5sj4yLE~5W!AU%inxo%kg4E$HVf90JC@yo*(pY+~9WFL<)xb;>aW)Q)M@Zp{ ztS*nt)D1eVLw)I_c|*5Y9xZDH>SVe5_U8Y7K`h%l`2t_TSLC~pN1??10lG@e<2Fu~ z9o2gUS=9i{eSiaI06qRj8wys|>jn*N`4*q{fGKIs=5Ffy_ct9WV`LTJS^jW&o5qPp z8d_nmUK<|rGECT@Ii-wSci{ z-ivGN6a6c(cPop_EW?RF%B@>@Q&N}`X!-fQ!uMPUr>c8%%g#Ro&Vj)@>;=Si_;@Um zUdsAZ@Ps`)Ti)O7GuEVd{n}D?b&`;Hi_`l`!9yrsM!y_8U-MQ}G^=k5@F%7wDaE7H>|$>s;mhKDj%DuJBv-LEq{&QwXrjZl&z{pi5IPN{F_aa zh2r{6pDL?d&RmRpqxaA@LEZ0`*z57Hg98I6x9ud7HZ{H1fc0OQRmeJ)mumOse#V7yX6C1RqX1_rUThQsD_WBDH=-}S zH9h^_Z^!$dIoIPtPAJYW{yeR#@B1T+Lbrw2pZb}m#zWdoOtx1_Gd*OivwXV5cmqXd zXGP*XevN*~3%~=Vz8Do1Q%J;(Wm6<2vgR`|{o1 zjk<@uZIr9NJl8Yb!HvgoF;KganzT0bR~TX%cK zTiaPjKIc9A9F^;x^un&d(CqRdbJqA5GCmbYlBX}`1>L%>#g-j{u}?ljtWPR;G1X~? zg%?#}q=7_R^KXBu$fo49dVCAw5pzZOkCYVax$$Y&t2QcX&)ykPJ$pu(E&aN(Do91u zUc2a*$dfDd1B*M?BM?QuK0!c`cD*uEW*#;hu;;43kj(@YbLbeNIYz+{Kvu_ zG4lZGM<5O49w+`KH0vaM3cm6C$Kn_sW7zQ@5-o@+rt0c6Fl>aaqZ|32#gL29evZfq zLTQl?2B`r5hNBZgrr#i!7?FBb(f`3x5){dHIae_R9MGJ_iAC;0Yy2nmhnq+xB#aWQ zb?o>741SypQ@IR^io~Ylj-AbnD21sXqdRP_T)BuE_wdXoDQLKXeqbar?hC?$qR=9E z8_kwo@9p$-P{skJz#F|QRuAS44)GcqA@N01+5CejR#5BPpw2|aaO>7BMh*`5 zI@iOlb(bk78Z6piY=~s0j82OADSI02O4+TrE0^Gz3oJ#=H|9=LPy)!&=8{j zWFR9LfIv`EP+0mAIS8(ss%&kMhy>u&Yu5lw6AgxMm%|V;C#v5ge54%uZDcaTD}+>& zgTDY41K&a^ia$o5Qi8h>PZX`mekRl0oSc@gUmZT3^BkWA41{PKsI)XeQ(-@lmorZi zC%~Kf0~JOpGVBhK1Y}Ibx`yYjpz%nV0|1WnfXMSIhaHQJ4HBzt-OGTO&~^coqTW8` zRDh7PDPA%!PZ_Wta{kI-YI_iYWgf@xqKo7kc!+J`k|`b2^%uC?eqQ zZUlNOm$p^bCM9GUON$HzLk~v;XGoB_en3|H|7ii>vViGau7Q5FE6y(BHCjf+P1Ubc zwGih4O6~TePW(I`3Vx=@d|s(X6#hywqx#@MHjw!G`YhNF)#o#uDLGu)YLIEj3Zf2H z=U9R#<2!|pBEl0vh;Q0F^qB(yc0)kM4ru}@ZLrR929TLnpcF4qv1RAwA(9i$v)vjp z624BzqlKN`vi8*I6rwdr31=TBULX17NkILeq>>8EzobLK6cH}7OE252hGyN#tZpFM zJ_5zF=&W4NUHq7d04q<7hp9lSG^#cvb!ptZDVHFWi6jAlrsOE7Yitw%2LgjPMMa7+ zI$8uTI7DxrmHnloxpO5I205gupGApA)NgQdU|u-S*3KB1pB*dc55nkvTdbZ?3naP_ zKLQl<5Lgr7?W9XYoCb2HksUPgI9~>s9^nKKMg-cQRIO0@M{tf6N=+cLlN)t6-UFtz zX#+uwa56w;idg@v^?UMQ2qu;(jG4AV>jpPlR!?6Z0_TvS0qLZ(OgeTThywCh{kTtJ z=n|$$%5Koa<0N4au$L=?Ji*EHbGGA5&DgA~sLOr`JpI#fMmOP1K65t12rU;zay}qf zDvD~1`l3Cn(gZqcm%e@%LVQRrJT#|jaU|{mW743PuSH)PYSr~Hgo>Wg^Sdt%#H(mS zpc+LY9v^V;v*_J%igDan0y^XW?K0X6s2cud-ni?0;|0(H1R&GW)4xbdyN?r>Xm262 z*~ft6tAHRX$B)0r$$4T`A3hDf9%C+mjX(>58BXf(^b_n9Vu;bd;xQ?QU|0IT;0nZ` z9Y~5Sl)^E9#OBe_QIw(-ufJ%ag&}j|fRB&GrlvwAG6UBwhc+?mWnxep(CwwfU^V37 z;?x329K3xWrK;`%Mn)|_2qbaC`V?~jVsr3E?#m*qCByDZDwkDZYAShM2x)PSqal^K zKFPU5+F|d{3gSMfOK%y16)vpkgE_TU}52VIEG2#Hm869RFQ%?-Q{5WX&|W2+ zbc`PPaI*+%IYls~wZ#+4wny@mxRrq85W6(GYeWRDk)E6vaA)5`|QgjX_7md4*d3nQ&>-MDL7qrkTDu7TmgvD4+{ROLrUnZwdko7)-M zHm6nA*z3g0-z=tnx1p5oYwMAZwX~t_21cdd8geHT++P}PmnnCy>tm&&-OPsSwjs*X zqoDPEn{Zf<`VShTg@~xsUt1S=8yYCl>DXR66>K?nIpkxkyjOKd8^yJv)jPMw1+sj& zkDhw;@%ip`3A+zG?6&kXFIxQ+^V)Y#Ix76cPc2LP{p!zrQtcQrEAHg{ESC-nS1L&{ z2F9v*!}Cq&1SM-0tNyp5h!BP>C47>)G(^*=bJKgMDwOUC%2Qm_1vs$+RK&%6!H-|A_8F3eTmd`KJ{l z`FwL-pA;{z(4rqy@QdcKSJ91sZqff$=84y8c1EK5#F@^E3EX>a!~D-xe0r+=FoNxD zm*hd3Ep;)SR0>=wYFciGDR)S_Eex4%fxuurwV7I9Anod<}7wO z1$3JzhS^k~Z~5rPzb+(olj|K;t8^J@W~HjJZP_v+HW%v7PjVJd&DhTL4KfLdZa!(o z$<2GzbT{9|lC^IqPCYuRa_d&_OMb7?-!T>v!P_tM$eJYO>R8S1;(E(4=<(pu8HK*% z*2iVV8GV+PDA2MoF=ekv8Hh$79=zCKvW5SJz{e(`&Nqq_qY7Mf95(0NLZ3X{zBTz2 zH)C_O!t#8*$Ljr_+8sY;kM@;V3z%J~xXjk5&b4v8cX;pdYDu0*8xtIOuFKm$F!e02uhn+&ag&aDQPj1>1Nk_B+j|NR?G6% z-3Ed0Gbx`tXgV|6YF-|(mtEt{1_NZm%q({OBPcOS)lv_G#;?h+WU3 z(sjgRMKhqp^qPHrW*<&!)K6so5K?WWK48$jSWs0tzyn}=;NBk}uy2Re?*953$dd9A zkOu_;A@!@T`EhTL_<*$a>*qs8{_HK~bJA5jjAZYv+qRMT4;=een4dsR0iF+0IH_AN zg@gTrt`rir7%#Z9x9!*gAQ$FH5?%_hYmtm|!MOsO7in(o54~G=7yOkRxwE=CMwYF; z-cl0Bssn{7TW;vTuK_X~DE^5357!aqv1QqJi2HX>%r7rOX3$7b6{ylld?RkdN3Qub z13bu*CG;hPaXUh1dW~@zB1H`MhjO4AO#{;AfN7=hIY3v94bbHC;{^!aBwZ!i82S+9WF2_&Qk)9>h3;vN6=??}B)gzo%tt zen*&#><5e}7&`Xxd6$E<1%hZ1n|}NGF(g+|`j-K!p ze0-G}|CWGiO|Q%}?Z_k#;5}wZ!v4F1LEs1@)O@eXjMg`FfiTe_liZO`Teh7@c(yBg z9fo-VQjdO^KVQGb#wwNGe#_mtq|3n?uZ)?C!DJn2SA)L>*K$td$B2xYbJ`RHseKPw zvj4LSWni_CxmN=}JW()X-CQ>NiDgD6KOiubnUxjcZ1%Oq(egLLN%jjNfqI8_6y#$n zQ$2C=8Ra`Z->SO;jUbBm+Zd{Zg^q4^m*-zyh%n$JR%TX3$xnisNR~Yos@fz8aou%u zd>?T6MP#CHKjg@bR)-H^6X+bVC5ek212IHlGSLj9`~ENa*|$l<^GhW%5QUJ)Qv{8o zwSa_i_no`qQrG+4Z=8$TvP1e4MJ*8Nd^lxpKxm5ScU!@lN3nkwk#M{Xt|sK3g`@na z2+hPq0-4pDJ=trUBH4eXlOm`B;i3)%Z(}(X=96_B1u1lOutHQL7 z?Ck8lNSGrEBBZKfX0Lh}A7lw+klfV!83vF>B*Plr3sEiu;8uXD7E3`onma^(?rEs0 zhQ2b$`c!Pk1Fwh^*fu907E8>uL%AGRaU5qbHxSFRj1DuE6j2_0Q$ zkK(IlchL6Sf+R4S5xxm4$w;I)j+fK5Hl2&}oMQ+5whN~g(Zoc&fyrzyei5lDOyIc4 zzNWZ}j%mCnF9d=T5*S#)bFIJI2#wwQ8#6*8A}}VEtj?KM__pudZJ?jCVw+z2PZn}y zP$tu8hyE;$r4KkBqBEEb2uBVr_L)-|3me-hBO|vi&lQYg@Z%;84QhW4ZpO1!fh_n- zO;1lPMfRir;g**tgVBhZ9-sPLw6OU5^*%%UX-KOB?@d7ecGQSS9e;)f2c~x&M1rkr z$RP8#*cmF)O!M^2%n;|vg!-9KoXFn-&qlg6oWU_h&G?uEl17z`W(h$wP??hMkDQH2 zvf;V-I!de_-gk`jWd^8%lt@ctG?-5wC>#x)$8m5TbO#y}v`NsNgQ=Z?-r^6SdO$nY z8e`QpH3QgQAWgrd>Z>*1>zlfWR107?O%z!Yx%l$@V-n={k>>fU58FJPx%`I@N0T&t z4DSFBabM~(un_m#N@x?rXCyV9Fiix=;h12RHhCAeSHd z-oiKag`W)2^b`JrM3G{aH3FaBym{|LqjBMp*0CTiw1lT0Ykb(UuZ36s5r4GK-9V<% zFZC)n3POJGrH@X3eagnZjrX1R15Kvv{WMTPvaW2RT8SC7qFkq;(VW`an)*Y=dpOUC zTiPJV{og-lpCRj=*^K#v9R}4Qc^tDF8CW+mv@!ttD=?nTx#M%@fTc0TdxhHiZ5iGB zxik6`G~A7EuN>aCx*WAtRY+m-Vuv`rige`uvy)mr%bR=326T;!+K;K~`~0{dDIy{) z9O13Qe6B-0@=}i#!_L3`vm!s|M%4yJ*q+r<9XPP=-+YerX%@@Uz(GIx!W@=NN&eIN z^%XyUQ9Lq^FIHY(W)*jD!7Vf+aSs$IK{aWrJtv~9t40~87vpnvj)Z>^J^$HA%<)gs zr?`K~8~7=_yoXhWGtb1U`>Alp^Q+7+H6F`Xwc##E`0%zD-n-!!F{gcdhPu$R1lA?Z zUUkcN-c@Nx&s|u&)5+=6tpM8r0a+7n!H&v<_8u|S%Gz|C@*-K@rJufeDp-~t)%!6h z7}Q0VtW)>+u}!OPSW<$Z?iKqxUpWDYa16Y7tMk0L_}XW8Zi?w?N_KXEXq_~Rd-s@Z zZ3X7I9^RkY=w~o_JgWG#!IF%dhME%pwED-38PIc24ovoSZrsFL`)PmHy_->v)SCI`n-1i7>BQ1Cw>;<2UD&vh zGR}MGN7f1Rh4nGRMPl}2Mvm|9)};M#a(VD=e(|i5Z{%0A%?Y&X3_xA^ws9rP2r&;v zXd2Q|Qmr(%CNyUyYfOD(dL(GtnNbm<7%$`V|Izf_@m%)%|3+yE85xzmLXs%SURlYO zT|$y5AuCiGl(Hqs%*;-atVkIZAxg3`Dzb~L-}8N+@9&Ru9_QTM@fp|kzFx2ATyO3$ zq_MYnyw1*G>xjYGgr?G{bGk{6LQ=Xl!-mt|9kX|~Mx^tL7G0p?^6!5&iZ}z7Ty<`l zm9n}|q@|MS}0oF1GToqQ7h zebBmPOzRMo<(u+OxV>s?GE!WqJK z`#m%@k=-{Wefpwtbi;DE1Ita-`jf1*r%kvivqc0+2hVn33fnU;%)r6ipy7K=h69f}dqyfBK>Z5dKiDZ2fvO^Rb4zkrdZ z<%j^Q0!-1NDX~frFlZAvZvH}=A;h7yB5qLS5n>e&Hk`zO-yPI8pS{djHpt)NJ9iZL zas+<}h}Hr%$ppg8jiCWipGb#Bkl9%wCChsoEh@bBkd$%=mm%~3(@vZD0rbOr(i?>m5!AbG zfa3AH023{u8!^fO{IUD;^vw9z@+h2+&?^E(CM5{ySmO;H+)Gt0eqd!#HgMG+0t7gY z0E<;Ib=oM_LznPhVFcbPtQ>fj;h7hS@;9I~@Q3`DQH7KLe_nAC9)!gLR662c5lXo~ z+4{Wi-`_`(?g~iwB``Fa8V>x4mDR%wMW2)t@%lf0`xc-jc-8dLx5=BKXNwR{-i9pEi7T{hWexVMJ<`2OK&1bl* z(s4m$!KfZ?E=!tsirmV!nWNSo@+kh6eq9DfKVUTW^vU(DD!v6et84`rqp=BS-u zg+NxMnibgGuvbMWD{RupUXDY+G|_^no>D{%0oqKM)%ShbEhEz%?nQY{{xFQv$?yJx zUjWnoiqt*5w!?zSQqXnw%zd>P=Eq&*x`m4gq5@p+S4WNu*m&wI{MQ2V-OKqJrwEkDem_(cB~ zM2dHcq;8HDdH}9yn!?5b11N&FywU25q(yt;2pt{Wp?gCoKL2$knhoi=6qZXUeJ(5X zV$dX3AM6>p(4i|s%~zTpejfP>@GXmNzuh)MaMP{l{$MOHHW2`9icBcvW`i#o1dM%< zA`$bI4T`H=E9YV4NSMyY?urWmqK|{$UFM|aA*aR#hvEHBX&~%>#IXb^8x%anVmovX z?0Q;`c0X@q>FtIjY~6P#gOLfWgNu*U)`FgDfbX7}nRvya%hCAFzsne50S&(GA3t72 zn{rP1gMj?4?2_wnZDCA7N5tgz?$VwK0%Ksg>JoUq6jtVpVDnT}I%R3|9_uEkIBEWK zSigbOgEAqF)6gJ84YH1k5~3%}Qa`^ZV{o=z`_~1emUIPSRIk+Ac33%UNoHxsy4Mv6o02$)Fq*qv=Xr=NC2(|X-giMR%=D|7l5kWA}G`+^w z>Gt*m#3Wc`i8+LT)Wu>15dn;iu{0?{*F5gZ$2X4xezzw!{(JPKpc;t74UfC7c%hn+ zYFG>rAt!)$4b`(Gm_i~L1jbBpr?d=MIe-cx5myb>+Ym;5g#rW$nduKCkWA2ZH9u)@ z5i=T`q52@BpsB?0Ey3@$m&R9AM1&~bX@!yko*_7lfB*jWjcUGF;O`u&K$s4J*@qO8 zOFHq@qd!49T`xB`_wc+9H-kh&1agG2SZXAUof>}?Bw-#)`t$%X+d(gbOAlHkOkRDMR&Ro?!C!YbT5>FhCbEk`r-&rBmKjikn!vohZ z(f7kY@$$r5wV$I=AR$q!iux8Qs$dGJ-i}xKAA*WXHP&IB@>uki3j{DL`D^BXXP8Z%KxYg|N{0D{Ot!Zk{)#DiIxUw{W+0o(Q3lpVXusIWa4n^zALSAJG zs~#I}ak{i2TN6GJP&op4q{5m2LXk9b;la*UExYuS#7;sh#gEMrUZuJeMQ~N5`xFKY zz#RvonLXaz6JPR$4GL+^%|Sgc&tZy#q3toeOd*7ydUbo;1WLdy6ZUF$^OM@2jg-x&;;_wauG z)yezHt)f%oWJ*kyaP6&srJvWi95@=Ou0X$W^DBA>6XwF6WT*Qvh5B1vdP*%1@TYgo zHnwlsELQ*eOr#|;2riTxYutUh|3py8$4H7_pZsL63@N@@r>sL;busSQw)Hos(i%JG z|L#0AD8SI_Id(m%eyJ>SW_KvfQ?DOFaZTa1j)(%-wh;xTK_#bIS{(JMIC3+`@2EK5 zs1GY~nR^#E<;yZocV%2`OKmOH>(`WYyCTaCS?H8om2~AJN>}NnDE>T{2D>dF$0_5d zcj^A7?upcGoi^6U%3xukLiby8a?RBMMHcli%2Sa_f7Yh-w1Ph!Sg#z;ux<~J|4kkB za$(k+??ejMg6=6PNk4Q{T+hY4P9py4f#*Hz8U5~qBp4h*6hKRQ`mK(x+?|TzL+-m1 zA9{VcKj=T>vQ2+_$;3+hvCdk;yF$-g%TpiMUHg#vZmBt5SjEpWsP3i%%~r|wf~4yk zxR|pF;);$-{V3@$j(G5Z;>g{%ZxvX6Y&I1zS)X8i7JUu)1{>eg@k$yhF-xCTvrG`2 zEcGo-2;IDAvn_M&gaPfd8w!1;<-Gbb^(w~EN-!zxAEtEq#5GdUymjTDC47fxpyq$^ z?uZL51niHGJu<&&KQYe^o?AN6{QOTpu0w~sc==BLIjTp){Pv;zuTk%!*i%AsDzAMa z9-8Qs)YNZ+Qf&SEM<#@*IPBFlryap|Q18@_`9^6Y{S%I}j$B?9HNPtr{%}QaQ@?U} zr!f#W5g z-o4n$nVP0P(l@~UI8#mIhM>IH!mXqzAu6c@yVZ__5Ep@LfYUt&#)#nmuLX#<5#)VU zDyM(?!p+fch>Mec|Ps_x|}_e=4*TyKfG5=Ykb`8qHprJ(5K{HdnR@L>$9{m<6G{?g9(;Hv;%_p2~`4V%1NvD^Y*^7Z)b>s9!DjT=;Kj| zh=c?_%u1vt59}^!ZpSbVac>!A)+E3|^k_vzM6{{fz{!YN0}ePX;H~5(pr8(;o}MQR z{Scf-0yKq`NEKW;0Sd{Z8ILyN2Go5d7mvk4!u6{R+ER1@0^tvcP!tX_xE+rd+ixOu zmUs>1DP;RH{WZX~of+MWZM9cu2q$#7fW>bfeP=b4j%fB0Z26H9*pI@poYf zn~piAx=TbFhLuGqu*nKSr>E}t2HUQh9$EOAy@=;N56$e zBtIriwU1w1pQDK4H4$|~@O`z>eh_ylX`sM@gt_E+b$~o>1>_h?z&8Pf_a-QX7vCtZ z)Q#{Xz%Y4M>f*akCF=z+X^Il_=m*6v!ay!tl=1B{v^^6R9oui zH4D|xOz~t$hq#gaN`jCA-{wGj8t~R;IBeMQhbZb9#Xcp79S;j1XW;S$c&&Nu*@5aF zgLM^)n*>+`HjJ<&QS4mLGs3YTL;qvC7XATgqJ%zz6k~oaJbCSS9T3ttP(PKyk+Fb( zh}U=$sgo~pus~ozGBpq(%XYwp4Xr!)Zo}i^jv&4TxGq^c{!62h$AOp%j2c>*85|f> zh_wOaPueBb)fZqlT*L=Xs!%ZRs~ib{T9Y)x0%e%NI!;=%82tgoLxs5}!u@OhhU|9Z zChE*9}ZrU+6xx9{QLkF-XvrH ze&Noy;8hU@Ok2>^jN2Y-6+M38D4=tk1%&mYe2{e(o1Kw5qa)Bp;=c02%;Zgk?;oPK2TYE;2`aYlLe-~AK9iPlI61Lf7D9MLy33Cq?5(G!mkR>W z4~gIql!Lh6FDN2EAK^6TAv`5pAPKpH%Y>Lq@sZ&9u!$Z82ttg{or|LXiGdLz_y{z9 znDA@;))eis%A)f|?Z_o%>1NH3sx@oY< zL>YuPKa02h5pR=#@wo7@`0z+bgp=j#Z80h>?L-eXi(VJeT0kN4eh_I&(^N}CTx+*% z3RpwdK#Sf0oqVKF3X3v{Ey9V5OSi;-5y+1_GMBz0VFW$NzjM0iXlY6I4}h~^FxaGd zd>=|`Nr5Fi&4g7aiG*d3HV0&_On6mXU%O-Rlr;9$iko(BuxOG*ry7yf^87-JlkD_P z-s?qZ5}+2dpf7LZ^oWGp92u~J%c2P%^X12&GzOuJ! zA+6T-?zbZ# zFwZoFVjpyXxcZ;%pjP77nx@Rok1gtPprYb`RaaWC_Ddq3Uqsa5>Q#P&N>292#vi&4 zdgG7tx%Yb?ZDW+VAY;$VH?Nz>sHAyI%_QbtrC*}AoObaiI|Et=2lIw6-_uf6?0#GA zuVzp6KKXsrPduLc>?R3OmJiyzWu=Yh3?hI5Q+iD1JJn70{^C8@wI?UB*Y&aX3%`TA zXsCA{{ypGbbu&Xyy{`J0n|$$yy4XBDIZlUb2IhMP3ay8l^&&&=8HsecMQ;m~KW&#p zAF9J|-&b0?S3+M<%Gk$iPtT~A>iNRR?mKrs>064HTuMs#$-(7VLdEbsX$r^Ufk*dG zBXU?(_l`l3FvnJA=7k04>oK=kE0-BJ`Di(>?0jN67M*c=JWr5{s{iYuy*@qwIfesT zMw(;R1OyavoCRzG_XM@*KAurj?7OYO*OqdPro>#OU_;!JWEmreNYTa2#BheK1`!eG^YIkW1dR<((oGc8R>5QVzM*>@Zp4k^l73Rss~Tt8WOu7WzO zAy`1au91loZj!j_7hh(Vj&0-GJN=DH^n;wkKt{sc+(g=&A(P{VvU$9Ie^&m|J#)SE zZ{Ll4A$xO5MdLGhbi1NiZ9PkM&*ka5JW+EI3%>Dl+Gik5z?4$$+0Ft5uI^nDu21YV zCf%ENez*kThl3kEi_}D4->ojui-YNLLvMP6Z=QAPvN<-HFxO!$nlb#eIKTSO#*IN= zt%MpIWqlpSLT+Dst4SSSNqdKlXaBI*-osumKF?i?D%5wlJms6BdSq8)<7S40in_)e z6FpvDdpgcJLo41Wh60{L<~ncaEDpGt9^ds=Qt#%So;?Fq_Fk@%7#TP}`JWw~+f9E% zXXdJ`{FvwF^WMK7-pse2SJQBAY-zn9@gYPnQ{w30Q0aYL#+!79Hf=O+xL3D3>fsYU z5pC5AMfWB@eMW4cx>{NN-#$^sko~7=!lTm#4w=`ro%*KJWm2PaHY)gopHJTB(&pDM z=DoCaSTy3gwrHfy(1kmMOEJp(u}%3pXzN_=Fm7J(=j7n<=NEL=(*1a{?-R%I7=B^V z`68F{u-n%;bv3$+RDF_Q)ZyhAg9G-><_Mz#fobD#1qp4b$udUIejsrTiw zp~T6P;o(_t?A@2`LZ1HG6jPjlhB<5ukgAYj0+T3)YbchZD7XR z9XW^RgB(jOF39XO`|6P^=!YEgw6?JgG@yks9smC8hQ|ycu!G!#6v4C>44S*|N$GT6 zVdT#F;XJD^daBdUaNemCsc@v93Sh15rQa?K&#Eo58hePc7}<4rhl_Y{SW=$E3v*lczId_JD0IVWhci_MHOba{};22_nG6%fQb^-o_l&*sG^Kac@L z-n5Q}#(k<^*)-3h^XDUsq*^*UOf>^Q>U@Ayp(v5`Ya;p&>E3Dphu{g=5S{)SO0IPk z&It1+)kgRNF03y5V4pmUOpVJGG6`kmLK8k1M!F53he@AD}deeNXklGU`240cw zF|Huk@gmWB1}EWB)C_5Udgf0@-V(4ZL3B3b>w*M|I3`KD7Xsw?F`lA8T`}Xz&vw)+ zUq=QF^eA}r*c#lk3WL!R&|Mv#+bG#MfkzIGYXg&mnAIb^%Omh0g$JFxSrXG4rnlA9 z?;SC(T%QBt;zDi+^ytuVlNdIH@!sf}aKz6=94r)7zy7}LXnK8iD{2eB?*MMBFI{%( z)Tx5q2B{51MvU=TSWFE5(DM*qAc%&H3t-Avt1};5mrc~e>UST)8btCyzk#VjoQ*Y^ z{{`~xacQC_R@-R_+M>WwTp$?1A6(zU>i+%hHXd8@s`_XAl+IPR|KP_oh(Fle`ww68 zG$doBhlWuYFfj>}1KtS}4d?)HI=G03MsGf@+HZWG6zL&0nLyCk^T;bjc3(e<&jBJx za_=x5Mq%@V+{TiU+qrWHJq%ii_8_t^3_VoK`VMsPP`qAV!%Xc1ib{>@H;Hq>WKTvU zxHxAZgfIP>39KHKhVLNG$1ec+5eF|O8A{NWdIY2oD&Y_bqb9L!m5CbJY+xf*5Wp`T z7BiC4bS5X{78G}+2|Qqxc#tsX>tzn%Mgd|RQhy9ppYSliJ|A;B>2fTt3B(c;N!zB) zn;`^ylnf#ct?(Bi%R@^ahCEPJhUj+&R0DeZgQfHJOt@#k@4}n84W=m zVid(M(E~aFk;lFD*@E$<7I|Fvr>bv_tZr^S&|JFcZ(4^QG*qY(RR;F$-RPplj7;38 zVD~`NWPLO3x{J#J=ZPNXcHCYhB?sRYxjyX0r8(Yh)w-H}Cl9}tL=%ya>*Hc7yjV29 zk=jQUSL0X*0x4k8lTEw}7y(1^Mqxu$MAjrD+6S&LeRK2O-E@ebsqVALvO!471AHrO zRhRt?gwv_4oNd=4^(nKESY!T)Ua7)>6<_$uZ{5A? zh%*vbcfj)W5wKxm=}dS4pnHB7l4o$fkfcOP!b6=n!HtUmmkxQOpt89eY;- zONR$f4Xk_9L?_QP92zBD`%*%bU` zip9uxgNEMovWcm^6bR=y0lPd~mRdUkMf7-S!Wj%Ar*#fiF?6%v@*OMk| z!*29!9T>3q9e*Z6hjq_^1YOtssI7e!2*wjH*~R}jR8c~+dATd>Yjn0M&7a?%`vR+C zZgSG`4@FiVx#+VULMN?odM;PH8GZp}8~y&F=7J~nudN-Ak9ja&UgEfQyWOtz)r)yL zN)M^iJKh%G`5PI(m^5-+^Vr%4Yi_O0-E;3NOq-f}*F1kPKeybsC2+!IOHt8-wYTdG zl!Ml9$t$z^X3D5Tmswv&Z~nHWW58Y6H9z%w62sn*zL3xkiD9q(`}T$QT@hgwY~H+u zcehsEn1Q3X)zi&=S8e&M#PQqRKTJ|-9_`dOqE=iT`#yE@pt5r0l(4~4x$1^pn5^m^q$Tr8g z{kGbPy0f@Qe^!m@I%##ire%S+%) zI`+fub7l**eMqZ*NQ&yw=$NIqA1?JtsJ6JMcq<`2`t#@kukW%`|F}yl?`9iOk9|7& zCRy3{nQNzeRZV*iy}dd6+Glf(hA?SHwvCEwWmj7M{yInN%YmaYs<#IMA`pPQoZ&+9q4){;@kFz@{ zJpF2GE>?Ip4t`dDuv%l^tlvDU6yHe~ua`)xrl#4^Q4UCANyoQ`9}r-Q3zw%vQ6TMOBqqtS>{u49*Kk$@c?{Cryp7u5Ymv$nB(V$7r#L zRsa6sOZizcw{PEOx>57|t2>GRadJ2xm`1jLSWL&di};|^0a|z-^Q39tISDp$u;3J= zjIJYiE<*mV=9AV@?R@3R=U|r+FG;B;?Ff|>K*Cdi{)megj$`EO$jFp+Dv$U!WNtRh zig5s4@eo6Kr+jv&5A0~5X30|xPk1y+!7glMY+R2Q0E023#)N=?061F)1_mV@n4qg+ z2&F|k`Dd35JmlXc93<}!)^4a^hhWe+Xng32ActAL)g~pd>5WZI=q%p6%f+O)#!2em zE@7pJ0EmdxGuz2Bv~eS8TLXeb6idMS@bTjx==0B+EQXQ@<5XT(ybrUE30F?PzrP(Y z3J27D#lT3!3y=dmM%CUy+QOU^&Q!|>@~akY=tq2{b?Ou!-W|oOGz3P6V*=ifNT&!k zIOT?K@2xrkaX?pxAkvZm?U)ziM(&{zDEG_zn=Uv4FiPZ(f9Ic==;)0FiI7SQHA1H7~;`08y_w+ZBW(7U5w|C@=uO0H=k-#r2GBa{7?&3@qF4 zK1Oh)ctiLbfnbW}L*B%vYVgv7=(H?fJi11>z7%!$ZJQuhi# zPZx-6y>jL7ihz(1DWM~wfg(^#z`Oyf{z@O4SGmS!amtZ-0M&xHP*F6={iR!SQzFsaZ7Ht-1(HC+7r86Dg}E)@ zok7R~BS?@;^k#y#xdoMBXr&2eUzl^4XvPsjK#T-bK?dd&{q)}<1G)eeIYlexBcRJX z-)8YCOEkT6yqGZl7GMBDoRg4OE9*6`3HD^~%c-j;n1gsPEMie4ZB!7Sks>bK&jII* zx>%&4tVF)SX=u)asHpB?1Ena$40JW-jAVT9__z7HX+Xb;HKvG$^lcM*76WT)`wwpn zUS;^dq`{jkA*TCyc~yd>a44eDH{_Bk)FMDDs)4sE<^Ik3QiTncsGJ5~BVSjOjIwk) zxe?|DuCn%}%VG<-$Z{A=yb*y7k=l1GX+-R=l0T=tVvKNexPghPhN7R;CF8?h-^?!~ zbG)U83~#Vxl-RxA3=Jh13=GekMjG%9|F^?*Fk8h=`{R-zK|r{&$SLXeO>SDxkX-?n z1b=c+Hr_bXBrD3d-#PLL3u`hnH#9T=GO5Eu3x}p1xDis!16%^{nQ%(=HvItf?e5uN zvjpHl*exvcxXaFis@}(y?=hk$*WdI1wEzd>b?ttmt)U1aD&m`m3xaTAZh=cXU{EJ9 zKj?cQQSMaK)ZmoBSD-dk51$?U+^iRq!4hFcMhKq3@m}njQ0GMi%Ah5Zymq)Z>W1C- z!NNi^U%_7Ex^YCi09GY901>kBWXk?avBMS|!glvQ`zvG*k!Pz{#fn_Ux}GHq!KN6M z-0Q{eTy!_-Z>#ILN5jHkFYRF8+tY>fq=sXZ4aLJoa}NH;u3q--u<1A|9Cc= znbgb*Q(f_sg#hIK_213jykKBBc#wwQ-wOdtJv`m35xy`vNjQu=ag$?o0uN zZ-=w6?#gC4fkuw#=$Kcd!F7TC4Iz}1r;Ynp>7WgEWQ zaql-me3kFd9aP6Fzj)OZtGuiGwLVj4u*#2G^H^6~w{Xd#(1u*o!^>Yr3zw#RW0w3E zQYchZqKir$a-*(qIamK0BcOKN%Y!w;t}8Vi&K?DSiw9;3b~%>Re_PS>8h20Y)b`Nq zlsq5ws6}@dDoQZmwzrP)DfTH&@72Ss~m z+$mB0bL{j92Ra0dU}TNxOVCHm?wmd$1I zE1!vvOFL3&woXwgo{4p{{KLbe0tZZ04I}S%{E?*_DgV)JpmRR_hU!ggtFnHUQ|~9J zbdrNLsC!>HYiZLO@Mnt4%W4^Z5o!+YI?`NOy@0Iwpg9oz{jC=(#uD#H9PG_tA1XehFN5El61#4f&Ll)R_B;3gP&9%W(ytX^8t+C$J-eV@ce|Ba@trL&$xu_rm>C1e)yf8r3 zOn9ZWIsM0i)6&)>jSEjR!T|M4c3Mjtd&=x2B3TGOC`d$C5l%cRDk^N9c0(zvKitdn zE8MdeFJDnfGowGbM^*av9Le$1Q)H^ZfRobsb@^~}Q_}+YWgoE3JyIuE-PA4~;w+Uc z9OhAyU&y+tVKVy6f#hTmB@G1FM5BSbzxU@Qx zsOdYafRnv>19Eim0E8ne8|B3ywyRTVAN;ET;Z2g@$aKlYwU=?tnaXG>dTSjMxVdw=wQl~(qf=A+unA%hhtjxfyFblw#$ND!#kbMUS zM{q;IgYesT?`|K$leTxTX7gqzgzSpaikcX`w{_o3jFjx)b1@YVjyoub0;!Bn$1MA; zFv@4CK-T90e8GW$^Y0XnyEOxjNsRG>9LX)Sa0u59ZrZr9 zY+$ria@EqL8;{1h>IdEWlvL6@J*};+KNY|?qONM&rn21Y&TL<{^JO{AJkA5I5PIhO z<9)4G1uyeH9NeH_vFbOguUMdcsI#lf$im{MS=**i4^FGtqXwdm%I2&n4I>029w;Ew zw=5Qp^uGyowzbq^)`|=cR6VZE zcA{RZ30ufYw^Om$?ZwA!PGet6b8Ux95x9dqA`{7KpQ{F!zvcW{C=G5Ne!(-PdJfr_ zJU~+slGBR(dXgp$)_ympY0#4-91qCPd8m>poi?mrA5t2Cp#KEhQrx>)F25lf+YOaOrJi4h0kd{#!KV-YZVL~R5~y9Lg)gcnLXwc!XUth z%Cu|OQM72#%S}{tsp`a@0IV4|6br-m?n=PpC4nGLu<0~ z)3cX8@1Kkdb~aqjR9xAl?++Q7J9447tvH}H4sypOw2JK{isI=2TTGad{6wQm_!%sx z>x`Vy(1)hETd5hAuI1`)Rda~A!`=sccfHL^Cu1$cz<3-QVKG7vgN$bV8`XSW7rJ09lHu5EPufxj8TuLUjlzr?e^q=2)oTeohb z?F}e__j~>{#*2YmfEQgNC7;D=MxK2a@BD=T_5?EWbMj znAh@nVKcxNQT6YqCv53x(Oz@ZbF2~_ugo>{gV)bPT2ACu;)8eyt0Gx==9-@x7 zO-+o37S1qiz&+VSw)M{3xG92=4=NZpit7J4PsADk_>ROnz}vHDdEo&H7n$7kb8e0T zEYwEAIZ#!52nw#wKmJ|Yxjzs`4;FuLMYvjs;}z3C`QD%*2z~V!m@w4#xOjN>>ko6c-f)EM~k@ zGc^95^5pB0=lL9xl5?7oksJ3EgtyuHJhVYroc8+Zb5Td}B>RxH3~N`>BhZ{$(oUR((o9P(yH^!Qj_JLV=6937<>$ z&~m0{p4zZ$)3jOsx9Uezi^+{5`(=v^8}p2#WJu_}`R1Yng|dUa)L)H7=~tr_0I9);E-SdX# z*%BMVuktX*$~e%j|#{7Owk{7;c<<+a(-%77=OInHH&ZTJMEg6k+fJzFbtITQ6F zUT@yg$xUfsz#-`-sc$=)mi6{pm)}Zq;^qgG>u5w;9))f`m?VB_pioCQ!IA%PvzlG% zjReW92w$A|*r2_5r!GZ8@3otg)0B?PK`E(1^xCHgOdy~SiB^DbFiEX(@7p(BZMfL` zNgibb7>Cp(Ke!x#HqgT0N$0!tH$!KPa?XJMCdyq(v@nD6r z5Wf^(wu4CI#JROy%W>(`#33U>8v{&Tj}LRZj8 z^h=;++p^Nd%5MyNNbxB;pKju3w-uc_0!usCCL~<|eT5)uW_G3gW96cMRj$Hptbv6^ z6LHkzsKlV4@^e7(U4`hbLbs=GYjGiOV=vAq%2(jKAfeoF$Xq@;@#%*L!i-CEOkYNL zJ$v1>C_Q2@J9$Rb$6Z=UC(p+C$}50|)&_>dUCy?%RfbF1kKaoM7D;@4nplKxu8qm9~aYZ%d_}4=& z*8SYAZ>x`8{5wY+-}vqZp%&O40=?3CTvvom!tWkN@Yy!TfOko!5WR=Dl=tY|g$w&I zq2YIJQqr&Nde(6JU)AEVvv>Y|?V*}}aF0m7MSm+QnGYF18(ia3sq7oc=*saQ>l}$f zry)imL>-(b9M!xpNqQr~ZjCepf77%!Qud)RqC;h0d9p+y1Wc?sN=!$O$CvLwL(9exJ_H?>i%kvM-G6r0|%S5j#56 z0L5UCVq%#{4jO?94;pCdLZB&99AWpH@jwDu9}I?%l)AXMJkg946cn6+kqWZeH^5|o zHLGT}x233VaDDMZ5F9j=fCCT=+ErVdn{N^hN(b}_2H@qw;B7FQ)^e-}9LGnFDJQ^B zsdUU48tdgh&#TdAT!>N|U_1OZC-obMhJ$Pc$xn!I(&P76c4&MY9W%Oh=EpxXC4G1U zj(E42So*VPv83n_mFuMLSzH|b0qfwZw736m@;4Tk9uR7x=axiC5{5PsMhA+#JY1jJ z;TuAv&8#Y*)x`5015**&Zru7fD{YR!@dW546MurKh*$s#_JyqIL->FTE%I>l!7G5G zDZV{>^4?J;{0f7om1jERneF8&N=izSCyMiBOSZXF<>~ToWgMc_dDUI0J;2OGTx5j( z7ZbCG9{Gt?xgiDnp1OhafMRQ)C2ay`zy6oj9v7n}tu%CfSHLt#m;5sK!L@JjEt3f! z=qEtE?i`sTu`GLIh&A8{rw@@3+}(NT6;f}+m+y7}6Hd>_m;u&@YL-1g{{S7g;!8e$ z{Qil^gZJ?7f|)o3XB{wbc*d-m<-$1SIj+^#AWe=)(INUwWtaw!ineluU8+N{f!*-c zuzMR$6@IN(E(i7lZvys=NINl*JVekv_~PA%8Vo>IFTEo$coU!5Ydf9;2R==@VBr-O z6+OE?C^|NlS1U5SOfp^<-35^_2cx+zVT~6Ahjyh1n2(Z?$_Tig<*DTJDf_nPLE+b=lkoM^0+AIKEXJ7@qxbE;?>WNaeD^tvemt7?S7e9#!h` z@;RK8b(PfOZXEBMbnK@<(kEj?1mo%V-^zlcyS5xPEt~vC0qCQD`t45H?~eG*vWeqn z<$0Y~J2UPa*1DP_63-$!jzDSt-CB$J&{4Lvg<8eF6B(L*`X}amd5rT?jgG>8@wh>= zDIL{V9tE2t#?i428eBrhZ2tUig_P{}OF8^332%s|_m8Ox?g+gNF8hnE}HJ1-8^QS$RU?cV!0 zVm1viDb@j-yvGH8-cxxT8;#7J1z4(FKhgv7GW$!XlB2CoWK%m~;vC z3fe|~`gJ*KTc*y)vd_V=FoikaTX!R;{`J4%vIe^v5-Q2KT`(<%VpsP$?bzznSIs{= z4?tj=s_4MH{Q<4aE|IcJ!@oqY%Z1)KVVTlaQ#9nNt;2kN_inA5U3}Y^(o!-Sj;wB? zx0o27H%~J;&aJ=&Guyx~=>v}o(wn>Ht$qsSnpE+(b;|lHoPW{r&iGf~v@cI-Wo(x4 z*2r}%ERsi$+En|6-D+jv*nExV$3hqD!eWr{N^b1{g~`7!qel1nG}1hN>}t0VtZ|8y z7~6X-@#iSjH7VYrl&cbJdy%V8kR3Wj98#y@(;6JjiBkh7L?{HvUvMyz*c|hn_QXbl z|J4W;h7JyHnU>s}Y8Acj!L|NqOWOcH2$b8P8NDH&S~*bi(@_FgJPrYOQb`O$ICM4s z@5ZULA|WJG`Je!629$(km`V4fD7sMGGlB8!pMWCTl}HGsOyH^mGdvGAhT^<+0t7jDK4&Ww`M(k%K17MumVn?^2zsUMgo zN+__LQ1NWJ*oldz)VydHcv2kZP{rAOwahaGzK%u(mD!az<7miKAhip9B=3ibE7(QN z5*09TY7<-X&Yijj1`%9al;(>eHRGjB2|&j@KJgK=#cw!AAt}O6ojNn4@w;(!79pYR zh^~MRu^wRz^-Ra~_4Sn?Ic=4$0x{C_kQI~_<_;L^$IOyG2({29Jb19I#3g-F4w8Se zrlWJ6)bVlEJcRb}jwN$n&kw3Q=u7Se#&@GDpmgZh%kH2t4d z;KIO$0Q?8)djt=p!mdIFWSdIg3&gL0=Lxc%RWMuXWg0D%_(Sjo^WoRYLs*-w#N(#o zVBndyRK>Wu{E|<}=QOS=p3gqgrmB`J8*i1WPh`ThcOTUJ!o*3K(>d>LfD6IFVc4Yh zT5ee7nhDqtJWENLMxv~VvAXSUI6V-;yWF?7!bJ=sPEy54lq1RH0f1ahJ3r$4CB_Cw zno5v$jL^1%?&r>HU&l(Oz2yiK1l=I$UvTG0xUw2Qwv~n&blM<&-~yBwS}M*G0s`t1 zZ{R@^%E_c0*t%{Ct%Zo`szcfpKGve+1^M-2_+sr^Y6?mop{WHkJqErymjf5U@S^WQ zNcVlhk=ekKo{0cOv*Ak}r@#5zK=O-F1|r}kJSYWFk%J^fF8QtAW!%0HlTF9tIe%+k zB}ni`fvY}vC#UfjAVUQW%z}zMpo%YESj<&oR>fdLW~Rfx%c#@Z-ae2%OHKf+Q-2d| z;RyvzJY6-80|GX=KWCr+Bca!?k){Clis&W~kv8@zZwESN{UJGfV)^_oyro#d;2NX_ zy4S$6^k(3E|Lh{sQ-ohy0*FK?a;_fTFsT11)xu!kfv^1&- z{ly}RpA4UEXP7)TzmoBcGi^lz}2qQ_*RuD-sY^3Q`4 zqL-1sKt$5`N~`~5`*i@3(?@M3q8TlUE_`);?yd2zC$Gu#uHTQqrDL# zV^O5vsc^2U0hRC)~?#3#t4u9rg*mf){7f;V=%o8w=UU*Kb9MciXy{H4WbHZ zuIZ|m!{=O?s%@PZg}mA$qAip+t*2+QvQW^*ycoUXP+S(vS~9<|Nr*_P_P1db{;tl` zZOi*zpFde%`8ety*UhnKve`be3<@&y8x0z4itp?5o{LDTEH2uN$bSz0{JnA=qxEkg z{mA>yp8nvR(vHY`zjW?=YWO?}b~73?{(UbAM2f~k79mvc&K6jP z+?qNmvGDv@=C#|`6`GnW0qhi7IQ*HoxT)k!ftlTyJ_;l4%B+pQ*A7LjC~@ATrlEtv zz@MV6jdAqD(($Fl7``KNdHe$p%$EJ+;#R*s3GB^pUg6gYt=2Y;jTkt%BRhL~x+yoz z*w|Zrt)e$PJUQqhCw=$1xaxw%dDqI$`db5nRMY}=Tz{rk{vHtD{wy}tS^lp5wCB%m zqrXLw>kd?ivF#I#x-rJKM?8)`QLp@2UH(G#aVlnt#-RvH${iYM^=^SXYJyg2{1w*o z_dIP>zP#dq#Ewn96=%HrS1GPu6EJk$9UP9KSw zJ;s!i9I)Up?c}$mQsn>lr;}-0e#+C9=FcVX!csb?g+$k^eQTvp_dTTz)1bVww&I3k zGyX@L^)cP2+tapz=RD-uG_+y)I92k?S16e0<5jt4v-iSW)`r{auYScK_h7E6m5ELG z!q2C1k&(?wZVIa%^Y4F^yD#me8;eijlQdL4m)BAeSmZty9yN0N&G;KCbpLJRDo9zQ zkU#yQtp3uk`S9p09&wG^y+2fQO}YcJCs?R2g=q}FT2Ej5j^aX?hWTVvlm|0wg<8V3 zt54IPb)~im-+6E~EKKhIYXLF@SAW^)IcFJ)px^wqnHy2KXr3c2M;O;YB&w`+!A*+A6z<`AK~wn}=TZyorbyS!z&?zB z6}sPBNY@e1xZXs8jsX;&Mj@~WI2y2rvT&B{++4`o6JZUfgqn>V&oYZ$#SqRyanlp6 zQ{D#X4HihuEVd3MFH$bbCh{7Lo|hrfMtFf3*Jf&{>8FtW9_ojql@}m5)EFQ%Fc2h{ zl^r)`KR_>?y^2WP7!-4v2hsiXcLfI1c7)3Tyw6U^cF4|(HF4Ya?SetESe`GmR!4Sp z%qwgN0?Tq83?SU{Uo$!)#G0wX_b0rAJ@h^*Rw?y!N{p|Wc(Rg+aqKHzq?z zwgiMOgpo5q=W7e>^c{BpFCoDlpH%+aU?)xy zHgrntW23c`$A4uR$~%8-9hxSCGLS-OY&Jl#&nnmet``vqM;DA>;}vd=2anA3|9~@iP#2#H9ak1T-?L%uY(i0{}%MtmfNu z5RA1}`dtFScnwGvkhj$03V=*dpB_7w2}ccTXNV{V;3+%`km^H|UWb{A4A4MC>=NBL z-wjK3Oef#?G6|Ovp4k!i@4v;~0d@*Q9mMjaHv>vC5Z6E}m#b_cotjsU*P~=G@1lupp44bXHW4suLMQ0jr0wzvQ@ouVQW4c^FxofpSz^3pM zu?dpf9)Re5DA&Me*%f?28v(@r;7lRZhj9sI+Efcs4xf*WXr9|(<(Q9+5cA6-iDZI< zmQ+o47oG&*#l*;XRV0^Gjbo+(z%&Tx*6FSDrRf_ehMMCtkYcac{D_f<2#TAYIT^l<#Jp$rJ?gnC-rc^rro)xq|dn>fO5 z;I)xxY8;DbEg2$rH@;0~E-ZMYsY3#-nfzLVP;rqMU>rfyFQI^=7v_fI{7rEBSwO+a z+?K-Vs~+n>uvhC#y#n{yf67P#6V=UFyI1-5ZhMa0>{{RG!%V%T(>H#E?XHoarA?LE zR_|oGO+BrMsm)33Ldi&+eOh0=bf4K*sjhLygXTx=&ewf8n(9Qc>9|r&5HtH1THRAL z8GrwDTzo%yZ~xy3j~l52XCHm|6=SyU(Z%f$n0VnBy{+O_#V2Z=-(pQZr%gy2{QZd-#@aaF92&L zVZiCerJu>&4J{z{oiZzZzX8);~^PlD~>IF`t-yDL8JtJ^=Lmz5bPP3&%M3xq`R!H zy&0TPwByK`7uu(P30b*o{fi%5GNCAXHfP&e*aAzF1qc}rHN{PI$;GB`}GpYhC5^j zC8&4ABK^Hr^KuHwAUN=O9wRt9TtaOACC>niim-T6dMk`swxHcI|~raZew8)nB>4B!iO3{98ZOa@y5cv@TIG&0kTvA z@DNK3`L^v1e3ANq;~0*dG_-#qq@iZB4CEe3a!mBD=x8Fv_3N%FW5-3XVM<=!(MLsa zGCEklCh}Lz9wp5bl(}@bP9A@qv-D4wBY~Ee_vnKMVnepNcO1H1+?cli-QUydxof+) zN8xuvR?d4E;A@{vP^PBwhi^yxFW+&Wrxn5v+lq0B3hC%FD#eUa;v^gKXAJ%eTapq+7@ZPoM3c_*DJ(?~b|oc(o(D%lMd;c+wNR<#MEIUaMv6?ruF_k#~AIAnKdV zwUfWrf={rWAGq!%Uj0=|+n1v%@(M{TlgYj37}vdf@0Q}zlS1ZF z2X;tFSTKZR>B%9pQ}9>U(szG$xd8FFSBm~^uOA*dK*!nPGbzti zdHF_*#i3otJ!3tb&w3gfGLKmGC$FZzk8p2pbnH5@>Mzuu@#VZsax-0-F(3NuKKQ;&O1{F^?1ERpBu-Xmc@r7S@~3z;sw& zxOZJ(&55U#<|ZjQvYJP3-kzh4(TXyDX(Qv^n!3vI)hn*qo%zq72mC_0{sWS%msbQ= zE`OTbZaZk-#w$d3v+4hF^&a3{x8d7($x4W*kTOGCLRLnZWtU1u_TC{|QFb9sL`IZE zMOjHAv+O7%J0#g5;eX!G|MxqN_kEA!8OQNF`TBmwec#u0UgzLi{i>p;r~UJ-(uRyj zS#?*Aoa8y!aI3Lyh^6~(dzq!xSC^8lKYzMsWgWM$%Fg?uxC2})$Nu8u)dH!PS{o7uk%t^#?AQ#^1172%`>yI(v+PZZlbM-mT-u0H`CKYC29rjDVi79wU13qkf9==JtBL6Ze9PS z8I`}kKk@WMeT@2&grksztkTklC?Q;dzQIOTkZdV*g|Vz==sgEW!eK8m^?@^4=;1>&+LLJjS41s;Cb!_ zl@l3SwKaLDYfLgLA3tWsFRsO}A(9HzgX;+CxP+WG zWWgz4zI+t_#Z7}GJaD6!g$2ElR3IhIzbR+NhFvTntmQP%Z8^|nL84!yp8EZ}aZygy z$B#IK%T2Vm?cA;tjl;3UrMaVnNY2qmktsIZ0CI|d?tQoP*_kjfQpcsv8F_3)m%0?w zwvu<<#vXJ%3_EWT3M|rpHKmY#NLU#qya1PT4lOi8xLPP)zgKF>wc8VonG>d0U zvarE1-i=>}f}Z$hA?9(sExa3Kic3mVwCMuQ5`j26VmSN~%{|g_?jQ3Nhot3% zkkPdx^F*0UwiD#^Uiv2AR`0oH72<7GT3Sk4beQ2iBYttK zs{wx|dk^L1A`@OTpC(uVy>NPkZj%`lP-FWc4Sxi19l3{z#C9()c_qp^$vXs~@Gnnk zva2CD1MmqlLO%inC$40(KR#nFe)EypChA~f+XtV5F?R zf<5646C?dnA2B>mVoO2F+L}o`oYv;%s~`qL@3zfy^Y!EE@uF(FN5kVcpi^+m-6pk# zmXfr8c*!ME91~h4*UG3$n$ajeba*8BSguOGfxCF8mc3AMk$MU44Eu*2Rs8i?9)SRhCvwE?`rkARRx> zp+=cmdSsQ{9*?B`dJ+oytY(#g7QU}?>S+%+pGq>U1K;2ZiRU1iXU(;gdMT7Eujt7y^c-*G4%YYW*kQFs4$c6xZY!XyNv8lqr_9*8ZZ z9PAOx1T?OTKaQ5w;@rYql8`J&r;6mP&8`je3kzVN9GZMURu%Ex8$RX=4bzD#wWwSP z-+TA&!ER6qiysFV8V9wO*PoN+lA)D@@<-Bph99I2WY9CyUB3@{$>w8oOo;eNPfH8# zF!p6=6(vf{wsT54qrZNksi9I(Ku8zS^Me4>kK;Y1q$J_-se5D`#LgtfajkXMw|99L zo&-5-VIlTzp>&d=HUvHf1Ox)oXQ4>_@JYh#!%4qTEcpO%=q}-9Lif}xN)ju{kQY0c zhQ0khu!PcZ6a&8mUrQFFLZAk*%fA3bJju$srTGkR5tmU1AvUjr1M^|x(M?>+Fl4kq zyGW?wc$E=+q^hA|3C^wi<;#~Ar%YLu`f(Wdp^{M6VI}q&a7V=EOH{{LVGiP5Vd4ia zzDp!r8mzp4zpcr-XU{6bwE-i-xj%nyB_z~i10mL_Sc7;5&VDpuDGBcR35rS&HZXVB z(Wd#Eo2x2kFU-MyZ?rIxPEd=(-ID{Q8fCDyfxYm_M+4|3m{`(Yjhbz@eR{X0`s1E% zpU6#@VyYDs%7xX9{l8Qv`3G=%B785nrG*>b7a>2dQEpA{<=?HpTk@Mo%S6xk$f7H% zcwwFUN2+Uoee>jGtY2qfn0?s8?%G@r{S&xnNR?e|)#5g3`eNDiUA9H!Ym-S!tGQ&r zTjSuLSxf>y*tdQVr;l_R`}wSAX!zTHl^|QtGRz7Ls9T@qxVw%$zIm0QWhkvjLELHX zl}&8qx4en36Ls>}9`lZK?>Dru_R(R@vYWr?UDOa#P&zl4nACi>qkOQr(BbJ;9frFK zYgIc8@B6vg=)Q;w#T3>%Ff#e>w&0cJVD*!_wRO5hWA)655-F{{&RdwD#you9AAWB1 zXV5WU3Q;k3>meDs>)e9XRUbF)tgJSmJ{_&npRifw5xrezmQem{_H2X6!9Ioh_eIvu zF#~4OyLWpYEUHvPGJV9b2^c%0qr%@x1%*@F*RR_kY8yNmC!P=_g_!VTqhp7Qeu-?~ z^h26Dz&X{HPED=-ZAl-+4oOKJp}jU)A7c4-@Zp9>A7Dbbuy$?ZBK+qR`An;#-Tr5yfyVZgg%(Y7aGzvH{bQEzbjcZ zr@Q<d~JwFFow91o=rzN^+dKD{)^&Kkf0y_hV;N)pXP}^fq4an~LSWqsOBb zNS&J@OvF^XIb_3-$}U$vzGRZXctPhknk=?%Z>Vj zw%YnzkL!Fo67i?tXXASsuEips2a)-r*LOI(x3*a{by;uONo8xxariKjm8Lq%{<=y& zA8`;5o{*+Mqe62!)2_0z*w#VVr38j!d^`$|U)DZ)BhJ74ZQ^XvZ>gg%U+&d^(EeCX zP-Sp!*^0k&o#U6T9i@p9hg)i`3o2^ksH>wx?;q<5sbH**-obf{qmuLN=8AspCr@Zy z`{ctm9m;oYR=R8cGOwP0?@69x8oT$NOnD_zSsN$C7%CE@W<3~hZ>Xp!6RM}Ud$~|_kkDI{v#upQ#B_@fecXC3=GLv_fOr6qg96J;PYqMj4e+tQbMF!Aji|5;dP)kQ_jR--NJk^(#L_jb zgeB1g9uIr!i$wHFrUbHTtYup+1;k;B2?8NpW{kY#;kG(sw!PIzklOt%z2Fg~)%kGQXZUMMXusCv@V=_+mS756Ge1_Bo zw&dinp$VzSAcX9#leE400$ke+Tm=0$vhlUQ-$bY9u})}jmVNA zWo1bf;9$tGZTa@Cmf>jk6$pcnLVzRjobL9h_Oh%j;v$A^8NB8JxZ1wVUW3%U+(E7RUnK@kIe5Fow+LQHm`?l`ONIsz zlqGmVIQx$qAd&2ZF)Q&t?V7_cIG&N73uXlIJi;T!K%$aV#qfs$ry`@Mi2Vg8WAGkx z>H?jAmoHxAA~7uh0@3tg41}oytP5@es3Vk^ix+#DMn>^Y?72~o)_D5bLkCo(hL?Lt z?pidO9(3dUugme6kcya@ZnKD4fUS*9%=YIn3`!j@j=IYtYDSToQ3w6`?ln%xsBk~x zGp6t4`dvqRmXfs>X9pbg;=%7l>(smQ`!c>jwJ^QDfeLFHbpIJqh5;murr2X25!<0& z#TW(sSruQj46nTmZG3Ltfz^HA5!SN^;~=52FL(_fyIyulz(XL0S*SOLuC>>O>$XKL zz-?YN0Ggj!^4lRCLzp}r+^$KkAP!_Wwn5_O1?~^A7E`v+M>sub5LK{AsQL7XxV*tQ z4=`tyXccTiF|#C-Q9!cxIYl8LpER|%hh$}ug^0Y4d)e9hNyml)K63k^446-V1D61T zcHg{DZA}d~qG9B@-@`}F3pzq&qb|i|wEX`?aiF(Dcn3EKBE&+R^eR_|g$zN9h0`E; zdC2JUkpI}JCs`!(84DBOss17BMYe}_*_CT}ODu2ycq*L4$QHt5Z-T&t$W#y5HPHsa>)`-H4XS-W-!&I> zF*P}J=1fiLJdE)yQ7Vy~k!QK@;sFY1CQ2y4brp+peKvgoHX`E5C(LH?{jIklk~`6I zoD}RQzZ#uMRE90#naHip%HiXHx-^O2KwXB3>nDf_F^l@Kw%WT^4?N+pMqGEdA(~9Q zpDEedQSI&RJW?dIBgd$O+#mIh!GDEvjz-ci3(jq<8w@348UpxxTh<4WUKb) zT^BxxpSe5fHzJL<$K{*lz3H~oS`4)ZnXV^I77j8(&9UA>XR%zH_1TY3!bbuPOA_uX zxJvVxp<4XR6#O9)z;8~%R|SWk(A+_{BV-wL5RkHxvX(qI6yf)7ZbVk4r>8jdb?}%H zhMaHgNdQD|OJiDrbzEzUD%eUwMiF*!Xa9wyxVYfP5xo36e|5f&#CXZa$H&Ireuv0W z_11I*cy2maDBq%g=k8ryeSI~ch>;h1hg$5&%K-3N}TDj*JHB*I<&K(>%h|? zBCXBMh3UxeepH=zd%@X5!-iX3ICHtwd+2|;0CoGYA;sz}qqP|_=>5A5@u>iA2-h4m z@nEDY;-GmyI*;I8;-y0-V{p;B4$rS$*ba#w&cXzY_E6r!P{Im_!wI{UunSZKzB~WM zfxS2<&$d;7M6T+02TYycKo&V3Cuyp&bU(m2_Y_`2^g95{NWc;?s8>|HmM&vdD~?;1G8&KbmK zytJ&}_PcB?;CsNbBT3fUdAYkhRVk8R2^tVSyr{_B#Sm?Aa&6U_U@ve?hh-1(11AGx z$VUUFg1G=8ARPt72!BB~mZWn%nf4c`tC;YY}&f&aT(&5?ob%b+|m(^A?vIu0YT ziiBwbC7OccV=~@DhOhE&#IzZ-2D!DR#{V{j7Sbq+2-i$_UIN{VFeVw6?%Ku^2l+|= zNnZpBfkE&jH#Zae9BE^SFdRn%Ruw!i*)u9+V5nBz|W)PEA%-Y(s$3{nE#I|6Z^zEI$H2n0ea4U$H9mZ`S4C1WCmLCVG z54{#v`XJ!nSDu>Q-b-QrTKQ@cdp85M#T1u>f&ov)WI3mBfR5yR5kB+DlM#b$r?>Q{AU#FHe_U`jA z{M~v~oK`|-1f0aJ<+74W@yh{2->)NH@wW7spHd{TFsT* za=ZHOiP4;)N=?&q5LeP&sQl=^MNEdxs3+v#a=)_0`p?upHK|@A^e?`#n=;8f40!R4 zdym>~KJ(G)YqI?Od#@;MAD&e(Oq69B&W(vpF)i3E#yKvfSS{}!h1q{+r@F(N)7#~^ zG~Uq`9Cwe}`_w&A^VE0jhxPYi7|mE?GSg5Z`&M?}2L1yYyJf$N50!q=qc)J8nk}aZ z=$ngFGw`?A!jd-NX`rtbVk&8qni^t!@MDS5!TPD6F6Wi{*QQkVQ9OCk1VEaX;ftFa zY}5xUHm~qK25ata*K?AvdrfF%eKo*b7a{`O6zd*C@AQJywPdp7T1iu zbgq@HUsXioWBg>X<(182QTLrQHx=vostgrr>FaKG$yg)O2 zf4AU1hp&CR*+h%_C`yK>G<1iL&}Bn9nN~bIt+pdR4pDP+Y zJ}#NOUH0Fz=)r20mZN(j?i_~hiC3;y`3x(k3~OqUcY2;~>;91@gTOy6T_a{zFK;9! zYIP3Hf5xmc>8FD)O)sDSCREOCVQp)%xoOZNKY{YZ+2qjUF^Lwh zi!!Sp{j+JxS`k&hne~&YI(i(FPDQ1stXpkG#V?He2-Ai#j@iW5wmlUW5!t`w6lmo;$uxGF0nR;@7!_fuu=#?Paqk93~i$h%jCTg$L;L)maR-U6T2!@hXPErkH)0{#*zvE z4pk#$w>5OD-3{zg1Pak_RXCN&xdOR47z~mJ;$M|Bh5kV?37P}R)deWVe&jSo_J$dh ztBC!u0``QyFZTBBYrT`0QALnZ7^%aq9Y;X;tBBO5@827U6&3;ckwY?E#s!Zq!mcOM zcoC`I@p%#Z_WcKU3b6@&2TBJx>(2In4Go&`PJKKmFAEY$L-?< zz-1!`5~A3y0>DGff{GAsa)g0VpYNp*`O@yoUNDbHVMHPla&w(One`m^e5qWDBNXM) z3*(r;c%)9sg++`bI1?o02W&>!7$r&lHfwD0j+v4GiO7^kSJE-`PmyTm*ksNa zf+)7<&u^eX1Q(GE>QYinyT$6Q@-P#vg!%I+XXitpJ@TOQn7JkQbqsWokH`e|Lcn(s zDk>1{4{kSdk4=9Hk5)oo{BeD|aSZi&h4S5h58`{Q-UL?$L@sgIEt6Y?TKy96i5V={ z2>8Vj4si{)Cl?0JkNFuo#f7-|5(xq{2i@B@E4Q-|&Hg#t$ooDGm=x?^oA zd94=tHy=Q;D*}}I_}}NxpWlK5Pa>(HZ9JajOq%|s#=x7X$l<$zvY!+iSy}F$p2Xm5 z7WuY`Sc`Cw!xR@Js7q)h)aV#7*E(z-vVWTa`YW__5`K#>qLiZ}B1iz%M^TM;bB_;^ zU=iCZXIPau5-)-dg&s9@E-%9t=)p-z5;22r0bmVV$kiSvK>-1)n#A$(adSADo#0FZ zEhaghg|we*e{H1` z@^&*SQ!Jd1s7^_<;fPNgI25?F$-X_iA7axMwfi6l0^YuoqEAzt!vg1@1x1Ci;n}71 zoL+2$LJ@@H2WD!|efHZCsc0+7iy7HFNtzqjzgVbRy^xqZhPYk6b~!u;%$dNbl*5-m zVowaQs(+Q!>4K_p3MPNC5V;Z>UE>zs6K0ss$)K$u8JsQ>D3y7knMtx%dU$`-)OKco zQy_8CyvLsMz^^{v8&6gSV5SF4X02oMtYq*iEp5=Ubtg!{@M@fSe8@jsTZrG*53%lO zRKZA9J+|N_!^hc^`shtXe~fiZ^AL;w-XW-IZs{Ho6T3r<&apva2q=ZkjJOvU6x<`8 zd*C)8Vz-lrH6gW;%CURu*Bx{X4_0z)IWRre4S}Iv=G^@JH1?tga4M0!td!A$e)^Mh z;IaX`FjQoVJ(B@6B(0=|opwK_^aP6lvo#MJpComJndPs@OO;r^NSN)zrg(3j*I$;4 z%r0OiO@&H_D{FuNeQ57*Gy0VD9JGB^DRu0E|Nd4{2u*Hz{DcC!I&f_vkoYGP6Lug* zwA8_1xZ(B5ox<(g7+uY`;VpF;P@i{yzvSv74FDB6`~^9&8Ge7VaI;*@h5CfW5@XVi z)SL&YLZku4Y4i?9D834Bhozn#c{>3F{+r8`gZ9R_VFgbluZN(L*wWM0Lrh-;t_6@@ zl!Us}l;q+CsufI5$L=?tQ*sk`6T_azH-BR`i8$yaaunX8IHPl$hdqp|r&)~gh26losUPmFdOviPGWz84)QFsl=|QOwSE~^F`l;NQqzu!`A~|}h>woUp ze}QHCxpqy&pULcv#QNIcZnkS}XKvKSJx~2E{&}xvPq$g~x-T8$g@$R;@dA63 zyBsBH@~3%cXrhuoYfu|c3To!azpqoQYiPmuOzC{|`Xkc8M=I4Kp%;6|tkZA%R6Nyu ze<(2ZzqEr-b1Kr+IHspFckNbsS@pE&n|z0cd(VZo*%=25E4#{%S=Zcc9KYDwG)elL zsdhT5a)8sBzOAEHwZoCTz8|nNE z8mDb0n8MnHyQmp;wR3On z9`Tmb&*`t#2;Er?h^wk%di`3;(Q(_(uP*)~^!t|K<+uBw@$SB(w)RIQjMG}bSoXOD z+$wrx+s+}`ZOixF_CDRb;0do%^X(eZQD+|9elp;4+`sg&(>8q}3;!br^!SdPs@R#* zH(u~#d3umqHd$}VsAk0d!c4f_=6i=&gW}`0L@x01p@Y$*2$1mB<$C?KB*%d9_mIWD zO;6t$S6iK4tPVA+snh-XR`UJ6T<-_^4w51NWaFv?da^vWidtS=+VJA%J{`aQO`e{U zjMD?Bj0cXzBw1Amt9MG%+hw^I`_3`9+4dY0=rA7q@zc_Y^QP95Hp7EoFWODz36{LD zwSIFsi7s_rH!LTomRI0Yq4w^=dYw!&z(s&2H^CqQqH+NJaO(}^*|TTtwvy+^tA>i% z9c^u@#YuTrbH)ktHS7-Rvqa5I^@Od|cgpX@fT@7k3?-D)*yi{;efP435TqJ=DacE! zh#S0l;gd8g9W?#NeSC`3cXLb_fiNtxa{ix@fx!+qBs>P}01y(M!9eQ??;CMXb!N^P zwp&lLHlGOIV@L^BD~?eSk0K}#w>=;K{aZj>{H&B5S@i1oz64tet#GXpmC3Lrlkc;<+=&elfU0(L3n6O*82^{N zHXHaALyUWdZW}@20qRP)2{1=EZefw0ZUjgZG#ULh#eWa4G%=Lxz^_>mToq!aNzlWr zEHc)Hr@i>V2U(7P0cAs5;ikV0;50UV5N>uqJ*@cVCg?rsK&-aKYQg(1Gcy!37$E%Q zbI?T0e6V1n1`sQ0nLrs>QZq+ONr9VAQr5_D;Ko7tOeyNHqNWY$7{8-)KEyNzR;b{_ zKYaKQi2{0cY~c6+PQ)Ke$DKrKfOmiPvu9ur+{Q41;B-!Jra%F5D>K)pACFq}O;$wjU`si2-dzEti#^+Sb|m-rq-Akj@;y?4z~4T^asc;0KBL z{8p!LH+u-S|1L8bP7PTwjlrI`IJ^v45{7WV2-8YR4E{u)??r(N8WRp#8_I=(ut$j) zNP!;&uD6?^&G>!}-Ot#=aNewk zDMT9fVVJ0a!6M%(3f42EP;nmXF`Yu=D20hX!P<*1Kv9Y67L$;Bv$Y6sL>1aPcH3{p zEYaPF3}+w@0(A`YkR8k^zQs>ae#1+3&=I=@Yro{+tPxZyJKr;KJ2xIltItOPR)%aP-Pf0vgn&lCQAsz!syu!E2lEXKT^1s@*pYbX8q*C=CPx$t7wNu;K_2EIuRHu88|`6&tE!V zgVO3lzl$(pH_4>3$Pk@0lH$>5bIg>YgG0%GgJ*8<2ky19wdJx~0>2%qJS1>i2>prt z4e)KiT5g#gG|;f9uiIjOpiIu}ZuBt9%*oN*_E?NQ7;ir?<&zLhpidFWBZC5{vDu6| zh9x#$NPcN2&VaE@{+l;$V05v+V4}MYempa7KR3)2tf&bo50k>ns-N>-a*McLa&xn& z$`AXxZKb)po0#4~R+|Fo6tsK1%OsWqH$6^2#_SMIeOUdj+%u=;1NHC-3P;4qYUilkw5#&mgC0!{X`2S{nO=lF1p|QM}+-E71BLy%ujL;ck?P zm5e0@`ly<*@h&4HBe2(^An0o8={dkRcD;h{&>=0W_?uwSNrEJcy~aGiK$~YPD=WYQ zEFF;lHcKvq9r=pQd58@=DK6j>n){8{kCQRWyw9+tAX)>%%wk##_MGh%hrK3Q7h=k< z%6upnM{@$@51K~u=r|>iVQ57xhdn&7$gY~QjTUP1sZ@P#TK$czICsrxX^lwTceE1$t`<{12lu)Sv<# z?=VaQ#R%{2b25DZqwg`8f;m3HUZQv$AlasHo7O4<$+> zB>&;uz(j*nE+7%SI~4a{%L2>jPZ&qJ2Rv@H84I*fiZ18d6bZz0M`21+*$Y`7zA@o*1y;K&&^C)#|e z8lUeOdxWy|%HE4A-#b;~zt^U-lI~l?g2vp!V7VyGR*K|xmG{;2HC#O1G~E+#UkzI+ zpFNweA>p%5y@KOeGw<)lGw)2M+H{u8$FfI#6ExVver>aTHaK2jlQ8VSRub4~m8jVJ zEK5nva>aXK<~rt#vYr>~Km1~OZ!>Absc{GZydnu@X&FGThVtNi?Y-i;$ftXBu? z^X0o)+s_OA6LLFLSDF>_^j9)>2>96io4(cf;B1*eLt*ns>(d`qitVm18{#wOyH4MT zYm%1NoqAv%_0lE%29?s)@&)zj<NVs>Pi z-0^$ew>-hhO^4>1Te9js)PF^%&4)J9TpH=rcYn!J(qmT^e&@z{ecPEI1}j2!-zJg= z%Z2GNo3S+g-prUVuw@VLd8OOe!!E?rs<%SPgB1kFD;fKd>ZyglezNQj`RHkQ2mFMO zr^JZH&}YO*?F{?1nR)5xw&N*HF+oigJb(Qb*K6h1PA=4^Mhy5?EE`U_LeOuAc>=J+V3j9gdC=XQhZ7OPFM=g^kL?a!7 zMljt8^Y5pVALIpIIaY!@Ee954RC>|UbabF!-G}eTx1&5syp-zCtJ^s2b|Y&ob@JMg zb5qliy!TU?mij32f`S8OU@NRI{a*C>zzy^1qnq$85T+ASBkAQprD$+%b#pYYoDnTY z!OM2u^`R9!Bc)ILx#>?A#O>BLCxw1s7JL;Kvhdj? zOg2>O)oi7Y%w5f}!(0FI@!Ealko{K>yh(d=NZ6m9;|~0PBwkw6c~^JZSKYcK?D5HF z+1=6c0FZ*)?*Ag4u?r%|m0%XqKfOmBaQeXJ_us7C+7WCAeGpykLrwDJ({W~zFc0sy zrOHtqQ!5&Y8jAiDhc(ll51Sp-;2+=3%WD?K7;$*p;GVyl*#|Q7&W7cEJ>MAaq0{sj z5Ma@EM(03Wdv5OJZRImBY&@w|JR*0HFt`g7ms}a8|C38r^QRfIjtQXf=869 zgJ{{L^m@ZeuGZY-TJ^0EPQS{Q?V&0$TB#o}b?L*N!_3DbFzwT+hnUwae%tfJnpMsK z&m9MozdC=Eu1uT&b3f!y`HXj|anWkTu`{2Bn|~_|Bza`XUM@0pE9G=Pz(pThD(1(v zbwt?pEVGzhNN%-RkDOlez6joq)bhp%Nr}!<<&Jrs1>5Pw^9Bub0RrycZ>7r$Pb03P z=-`I5q(dR~g6(>(`A=RO9qhP2-fD7;>+Qklg-0 zK=R95#>*LFv^drm%ZlEp6Y=wt(s`1+TQP3+qI31g&ft!?S=a0a>tlZdA5ZqR-kbE> zm3idjgFngMA#J|NTN`MW+ji;u(~@+-Hpyz=&|ipof( z{niKczly5*%9X54SDr5`ndA%2%x)Yh@v{zU%+jdXJUudFkow|PRu)U~+nXV$l@1i3 zPyu*Myl-*atHIl^g0@(wwYSluH(Y$9O4kFa2t{!9{j7JC`^!GU3(rv{XeK0cP9koV zanLzy_yBm{j)|a9^HmZ9m3rqK%RV6*O`L*h;9wG8h&j|CLStyE0oH!&0@1)TY^*X(<*^GaFm*lU&b zqF=7*I6N6SQAV5CNFye(HZ0BJch>Q*XxR~onML)oOIagOwu>@>USF#^rS;4xRb!y6 zk=2^FcSNC)MrFeP?!=PJTu<+&hLM+lecSbwbhP#E6c!dT-<&<&e3K!d`hnj+3FZenC^+ zFQ#mJIvT!Q7NCn=S^thM_qz^^naiKzyF}--Mp4DQ=}|3pOLUOL4>B-ZA! zlAAz&yKHZh$JCPgoarC2%+uv#i(NIWRZF4TLWMg_2u03Y*g9x zhJbAtK%sJrT3iN4@$iTWn>Ayr=-Bu;xa0nnWf)@hyA{d4o_zEAbu(nlm^Y$;Os>29 z-vzE3PcL|h(+zf;t1`Q{QYKJ)FL?^=&Z2J?q^hfPk54%9Bx~NlsKBdH(5*w5URYiG zPWVbd_vHtVLSq;?lwL}7Rr}`JVjh2f&0nDL%Dk^eaQG*&X_={~Qa7(f)yelJDT`;5)_4xOW$;Xndee$qYZle~OL;rF2dDR-hA;~XCtb<#33ihlU6z!UO zP?u@2#v;)%v(dY1lX6l`{O%uZih<8|+)dh7C+H~E@mI>iX{y$5E&s>l2y^ns;>*QN z9;&BWYaTkNwC?{XM$=w;_iP-W3Pp0uUkB9Xhcy7G9R9t&MN|Wi<*NL;3Sn)%tq3^j zaY}9~KoiSKmtxzQX_kMKJ{UH=kT@u7V^V9!7p75Q4_VXY8}Lc`^EC%|7r_Hr5Zy(jcZ!f_wMcUHiqGc6=^!bn?iZb8lH<20h;}U0Qope{c{y(8p-x zh)EL$5nP$A-tCiI4CN+FWxbo}J%%JWZ(b7bV5s0`rNdxVo0TfzGp2CM2^`GVc7-j` zBqQw)I+CDG2TeNjZBMj|X|$54LqG_dq!k)?DvK|}I{<{Y>seVR=&eX(4)i1=-5br= zs6o>stW5Mvy6_st7|KPWD%u2ilRWp8RH3gEe_TDj3)53vy&_YvCnBhIjB`QKwP8nq zXWlmlmd;F?WEL%tH_aY9%X*{#8zg1enkTUcy5;EVd3><`6rRt=O4q-zJbctqe1~S{ zr8e<@L!4TA%(u`z z7!26Q`YK}HM}{swtLck5(^T}@r!ii({LbWCsZ1~ivTf?HZ|bya`7Y}NHpZhYI@_mo zFJ1RCGw)qlp?>&K0K)c=!vV=!tQi~2nebIBeOU(vHb6f;0d|JFFJxt*$>I|upN z&^h_IMdCbbhW*YnmU|o#XBXiYg51*Y718;)6O)Lc)Ffr)XxtH z`1_9-zcXoyk92kYDL<5{CMqVbFVRr=B!HCTRmUcYa% z^<%sz-y@;ES%X2b@4l&N&pkd!zc9WbcGSJ6W%e4S{X<%+l;=H5uqi;ekL5HUin` z;Hfc0Tr9!+MDAbmoqw4@w(%H#yLA{{J`8RhKq_1znU0;{x^^;CfI9y0%yCN#O!fPw9Bd+mmT_e~bA! z;3hzifvE~b1evlymq$4DfY}=tS5k5AM}qf(k$JPn!}hi0(gnu6RL#~^IX%JTR>`D> zqA`J^s&+AnN^GWEnOzU_@Qyx8y=m*NLSu>vm&~+QTO8UUF*~C zj^ithgzRu4Q4E+so|X1mE3=D8C8X*?*Fh4U0Qb=ip|D6w4qTN4&mOKpOUXuUYhN;M zZF+y`_t7)F?x4LOna9w8>&yHy?ib0VEBftdjBGK&7bOvGp!gEC0BFQz1HNrVhxYB; zbOkE0Zl}v7^Rqss>xlCx@HZD1VU)X|Z6dz>x~GC=S+5p2(2(?1R#(G{=VnDe2~&h@ zfK{o8H}=^U0RVAq^wcnV1nYaE;Thg}q(MRoegfqQVhN#wrBq0SyNJY6*0mA5TONqD98dE069U*m1&Gx@pMMX;yy?0>`YW6ZFHufsN z8l)Wq(_WpWD!2F@e?47xdo8DxBop2^V*k488KW!$VQ7PzOcS0lpS^eQlnyHw_vKy6Uy(oaDh++X}Uxnq1c(mdc(ZeGv@aSYKwpE;9$_3O?Doj<^nMykevpyJZ; zwl}|FB#E~Uj!EG469ZY);9niyj4Nxx)g2U~qM{<8f!Gi7rIuD!1AWpZA3aY1%#48m z6Yg@Y#qqnA*inc$0W1#ZJm=e%58!W&`cj7#OxBit?^KNJ$-H~x#^#1w_X=L0eEY9O zzXf;I$f&2^7%Ed1j9U;E11B0>4&(9{^v>|3243t`z^Aqa^8lC{t&gs5{cV^dRfNMG zHn1UxPKLe(6E8qTFz26m^O|3cpPwILHd>I-gP4Im;u4fHjJ;Ir_0|QKKl;26+>teJG? zKh&t|kR4V_`S~*>P;}$sg_53aDypjEzKRNcRpkOc^CAcaFI2hp=$&!KqxaPv<1%dY zJ3I7QsX`bJ$HkdL;^yUT%F?V;R$O<^VN(H(+a-GXPe-@dcQg-ae_ebYnP&}dsOK40 z8(SUKlOisIH~bn*3rr@=$t5Sj3j5@$!Cc_!llZJqyjV z_CLSBqhg#X@gVK1h$t>+-hK5~_E&r_l@$*hskCxjX3S)Pv17{!HQ)5KR9k6c>r1XEm5O7X-xgg zlQ~XbKDK-~C@%A#L(fl8XjWW(U;mgtUf|^8#A%Vj_fLbh|C4^Wn0?~*V^1!u8QtQR zx?u6iX5Yq*yvv)t^vk|$_}&lh*i@?~6|(tg+^sEaBLQ-4U-ody4h+^G>bb98_VE6c z56@v8m9lpNNQ;3xTnz7=PfFIwAC(*Bq&>Y{iUXv>=hpO2vod*5Y@ohB+WsYsTgY$e z4j#}y+O;9^BWzT5cKn)&Q)Z&nyqvHztgNO{l~DRoY(4HK!P)KE7gxYRx^X&1m z3ztfIIs=aanJH{`Y=o=5Qvq?c$0sbVE>G?e5?bVX9!IAQ+vGu7$F2-@Chs!W_>3`Aa3KLbzWfk_|gmesGFV1T3AjVG#g7kia|l%%M0$IRoWi zIBq&D7fxcj<+1)4r9x9x(S^)FN}v;FP< zUwBF+nLy&ifyD*p#5kLv=`V+eh?=^(!z{z8TZ&MI!GQby;Izv?WjN#2PW#umVbB0P zz_b<#_OKg18OiZ4@Ecfza3O@23yd5Hahl{G<5E|0XI`xIBFPMqQ72*R$D~Ap0E}FJ zJ&pCi^X}l4M*>A+W8-uFw ziyd;TF2R3=7{GuW3hFIJ?0B6!WHv+jg#i?d z%my+NoOn@1Bz(?b{v)@#$QQU&e^J8Y=WXSjFKuluxVYJp-WAP~KW?q#^Kj|7j$5G> zNm+RI83NG2@FjvxoTMzw%%FjhnH7w?Upagefo|P{Di!=f+~r{`CR~Z^V&4XFr|z#g zvq%#|5!ZqsXK2VYKyJi=@)oz0q@X%hJ=^gE^UE8dB-(feT>oPFW-$u2HnKO_r6Kr87Jtk-EbZ9;M~<};G; z2krHnWq_S|t;hNq7@d!PFO|j>a$%wfJ|bOi$0?*%LkK$zLF1Tq{Mp?_Xq2l6vDeE> z0$sK_p11tC4Xnz^3>GYk6!FUZlZWP2d6d{9fQ#GO-Q8_>aKWb(g&@%l zlY|nOzBP}z{7Loe!;j^qPLW(5(QuITP zoCy#pkKu*`h1d?_b)}8y41`qE#D*0c=Gl22Rx0dIu-hc-8Io7*h-dheFbuAO-3gg# zBfAkDTCvA?6~3ahgY%a&Bd?(mHEOQS4lYc(a_-maitM|8bvM2<#pK*dzhdoo3L~Ya z$B|e2Os?*fru_M{i{`Hd`-dQ^S9uH;RzGpD+SsOS?ROFTF{^MPJ}I2#z|fDMhdHRa zbJi7^iYTZ@-TZof=S9E%p}7aiSB<03nY5|03|2n~kJxf>mlZ27FZXVCUcTpk|H9)= zXTH7vIcVDf5wEcDjYl7ZhcU(QufMlXd*ti+*h(Ic$J~sO2n}%npXybnOM5!m* z@Ekc2TFW|`%%|dBXy{JMvBt5)&fJ zdNpgk{r2IJg7W@t>vNCF>nbVBJJ9nBMBG=W2weVfeR0Rz6hI8Wdv2&daM-;^ti&3tFomx=y%%j9@2_7<{B+aJs$o;;Dyl}IajHzK69e3UC+;^rcx^E zJSt+Rrw$9hdX4gW6VuhQTz4JA0hP5K4ROZKR}$F|f2ypB6=-So=uCvEG%t*dzuMbq z-xn_bl&)s=;lu_G9mde0a#&bXy!4C>bBPtRrx-42pLIH&cU(7Bb);cE#gkGhP1X(O zXXNFT`)BrDNTfR~8v~`b&<*VyCMhv)HuQ=H@Q*WrOj!e%4Ud*arR2Sa3yrvSAZ~cls zg@>Z2gnhJADU|<&PHb$Rb$V}`Tt|hvvV`l0>e|x*+7!$@mo*ERcRD&8jTkG4=T`%w zOvvg8D3x#>P%JTa__=->yFnjT1Z*5P&|bNLMuaek(D+mj_CrU@d+=b+=OZ}Mh$#w- zfI7FVY|;iA#&sPkoyb{2nliE9VV88{=eg&&7Y+Luj$GnjuNC;0IE4V#hI{f@Z<(BD zNoj@kzKB%sXzUX)$;rKT=Fg2=jgP^r6tm`zFDKq&1P}EHOag)``b{9*1e0&&^s&&v z!9n8EMH(V{Y6Lu<0t%t$`}>(WpS1K2bn@6Mg$^BpRAHLSa3)!Ui4w*$_8mFHi_?P_ z!dB4FYM^^nPMhAon?2yh(^n2wc||80Tb>;jqYr>=osx=5M_cvy4??f>=+PC_Um!vufZZtC?L1F-VLR0hXo4HV3BEGG;EHDr;DC!_qmgsOI z>y_-lFuOK?eF3yCjMs@O33~D)M_9^w8{tobe6&R6z0a8Af5vA~1VJ7?H{KDAt{rrF zNrc69=4ZB!@$cVHbStq538%nA64Xj^$ZMSOl7WBCG|nCp@^(O1ah|&Y3@50apfiQ& z6n!R+1$3Xv0RaIe%gSG11TLKL8F;q6fIPfYmpTGh&Lw&OV4pC6;clgaI7wL7)MUiV zL}UVqQXX_^%}=g622(Q>n~B;%hlZj(vY}`j0zA!N@k^}0=6?4W_$G1FV&+0rnZP)} zzFN|oNj)c`cFP9U7)VJ`yN(g701vj;-Cw_|!kHCv94pX@aG9{4hl}CXi-9X*jG<`^ zWIFJHB;SL8l`kWKwir`zvTnMIjsQwq*O8Me6EbTwZvk7DrwxMx;* z4=_8p7_T?(|kX}qV0J%63QVZp2v3mH|h`&jDwxsH0co+c4BSE}) z1u|1o)|VRRDpeW&9%_Y>2SVcq60W=m7ac?_>f<+W-h4j#T!;>P2KJvB_`eEU*0E~H zBJvcIgXk}Ij$$-Vhdv-!t@Bg(y?6*x(78bJE>1~;<7Y=2M@U5pbq5#|>ur2xk&Xw> zBJm4=Gjil-t5%DRsPa)rZ9<(yEQ_!+O>lD4Qh+UMim3)>W;U@(V5*WJlNT?9i69Ae z&Pi?U7|%wOKCmXeB+PLmQPw{m{5ZfY0g#QrKQNd+VyLxd+2Ijy^rCsD)uvZPj`#Rp7Au zM3!}%>XH6_3YD9x4&9?aK2U`ZvF50$olR4EE6a~`jLgmF1t_RN2){J&g)Cgq>M3T|uZPzrQgr21aoaNq!4g19$83YxNurY7&g51Xr7SdOzx zSE)PRP}?fS{7`pJ!(J7!D5t$umN)}4IdF>d%vu8)sj(PUJZC3xlr zlH4*=KcGk367bY<>EEh6>qZxyft&*`M`jf?ZziRXduYYYulVTk?y3C~ot{-m*rm0V~&AL1<@8CUscf-o~EL9D2u!2p2Ojc#z<4affeoW7y+Laq5L`i!pxPN3x zbCi2B+)ZuQFscw5Z5T_D;wO zm9j_3%vMPm$taadln_}dBZSH*dvArziZsZkLXwrN6ruj-=YF66a~#icyzg_n`Tf4% z&$zDhI>!oiR&w)_YDg5Nq)Zz#9fWRmYVUK4RCON`v#I1T` z!HfSRkrJy$01g0AVG=a-55)XNu%=FNJz?%<9YH<*)CPr zmgmje#R;|~A@L)-UZ$WY3}8AU!SH9h3Ie;Y(wS1|>9cb)Giw+bwV8!;p#2JF5P;nD z)y(e&n0An9I(oT&PWjF&@Q(~1-?DY9TU@eY9O2)gLxu92H*dmV z1IZe8Qf!%7S?vcmZrz|BWK#gEB&atrz5{wRG&Gb$wMTO1klAnRw_@&B4`k2VJ0}JP z1zovMt^1ZGjmi)>3i~Ebc6P7hkAF)@+F`tcAxFpZKcVGjR0MF8m5$f5Fq7{p`GB}7 zphKHCue`Xb=@yMo4yri*A%L(lj_{1Ag1B!#j9b56cuGD?KCIf^d@o&kc*fOLBT5i> zpThT@z$H-0;8Q{Ghs+I>d0@t)z}mJ>wL`!5nvKQ9Em@;A{hT(n-JO@oqzONYBB1oEjut+_zDT%2yiOZZR=2`)I)wcF;F21 zWOW3Rg?7O_#3a+u(h`XPB%$)W&)6AyQO)O@1d6uw^!5gegrbQlU6x0V3OL=Xc= z{2BCqta0?hnPRy(>db6788561{(FRK5KQo;f`5M`NqLT(jjLZH50RXajts*wJR^vS zQ$9IuzT{@1+LxT1EavsSr>jdE!UjN;lq9$m9wIImHy{{d3y1t|;`hX-erp>$duZ#Z zvrxt5<%81~o)Yx=uHr8?6rjv76(asYF#aOh@_!Dam}F*Rx+xM0 zSx$`5;n>ow9GikMNaERjk!#yRDqA#O2-1LfC;)hl2@HbJQ;;r^q$^@NM@%UJ!%p1E}CyV|Uv%NARKQnTKSv5PBmN#jyI9AdipWn}}IJn^eKloKMwC%sBxj zZTwH*Z&mJUm8f_TH8!#y#2;&grvjRU@QH-EQR_sJehKB_P6YeFEBPU?ac}`Yo_&Z~ z_iV+{i=s~F3<;`C43NmF4OscawYl~9I0EQL(2q)?lI6ZOt4f+6cxZQ|D|eP@#qgT^ zer^$MSd9Du?5x{bkA+WPdYZLD)c=rgp`0V*b972YNZ>W^fu(G@*DXOEjQC#Ox3bxUb0A6xaK1QCOXLv)P%cmcqStsNhXnJWZ!5XMJTB+$UahX(*=!$t~?TJyg(KFEQx6) zJH5jXx$lcIeD!{u6iSl!^{I|4Oo97lY_)5_&W}EczM+AtN82ZG(pgz+q+r1EwM;-@ zh~Ul%Bl|#qDjrn0YEe|SO&`*HK6EVpeBBY5kSf4Wm7PuH@Z+PE;jstWsEV!}t@-|W zM0IDwiGKr{mYL1HMOxo4b6sB<@B3S`F-MbEPKdYO)FNYcu}8v3fjMGvd3d9;s=QyO z$Ywr1Dn;+{Wk>2`lV68_t6L}BdBDk~CGfp(B(1o(OuVNoL0nxUpt+gB(d%cMK&$&w zVN&r0Hu=)Wf{K1O#^|N?xF<6tZZ$}nyZEP6uizk(@OsWgOSV#;TP&4T^iOvGUU^r} z<@mmbew?&P*S_!{RI~rOiTamiG6zy%Jo< ztMpmdDO1}yyP3YKAZGg3Pg%vwRl)Cb3QrdVU7GsD7EXIdEB>yM2P4mRjW`J}Mn(-a zHLB6}Da~hB_Tt27UM_OX3rq5$y&cl!@+lx_s`>rih=*xA<&M9K@5oC_@7xp@DR5gu zGmM!<*LCnmqphpHMrW6pUV_t?mmJ}i;zs7eC8rCLsvm9mGBQw5(v#CBW*GEI;>)7w z_x{o0u$Z&_4d@ZOb7Mn77$@4*4X;(-8QSL8Z=EK$@B2pYT=V}3k>(9GGJP^zA7pTBds_Z$F~N0|{{G_j4ihbJ}oT^2~ z>zYsCYuy_v7iOcR1wyG982IHAwbg4hOiDgZ&g?eWMBmfC-$GDOkk9o zcK@+5&(YQh_!0HXI5R6UGsi?*zdiC1H{0OZy$=pfWE9VzKGirCh-PkhnDU76Hnf<6 zVu#;w6Ku7d7+tmIul+J7seh_tDa^qKRqm6*0ErW8%&i z5R&8l-QAm9croto(!PG7D??+?LjS3}pf@{xWp!x|7w0S*{-Uc3jwmzwU1=fr2Z#ml zWC@%$CGbA{e07Wq2N;aO`Jv(@kkEq%Z`A1(o8Vk#&6Uc*&Q83Y38qhhnF>warAS$z zPGrCUhc6mm004gQ2k|3nL4-pg0uo->n~C00K_LN{5g?#m-G5!M(uPkdOh`VwEU}Jx zYow@+h|vt86FD#pgHb>G{YALsIIASnZa8k(7mITp)r%J33`)R4)GCSL55-%-;E2FS z!)y~8oF~wW;d_v)T(@7Vs}eX>!_4oC@Ph|vcVFrx6)-2v5zr=Rz|LCiy`)YQkSJks z@m!WU5#jE))a^$(A@Z|}Ni3bc$4sKI8Pf`t8T)k&*YT8zE+R*Lue#{bc&SRb# z@P3TA?Lk|(PIU+=8K>|f^He;`k~$x3z^fybDvGj*N008HGDR)*{c}wsF2X5^+O+Z{ zgNzl;yd~J{z&~*Z9vx}|(7fFdset6*AbH!n&e`d!zLeDCKoGGg$n3W($HGoNaUjMX zb{j;SjE`5>bt=j!9i%KNzArGYJa{rawhfj7PlF?>a} z=Cq)V9s{wqBo-Q90F3|{*OKkT UX|(tp9-&Bu!=S*xr2rNQWr&Z9 z68NVbFAWJkc23!*K44t9-T%tl&Zi-g3+vwz$pNU9iNo@CxucaNRg{qIyfuFd_Dp5==q2O}9W(grhn50TA6sL-q-PkFE~I z?9ESSPlJB}ejB#<7VtVKWnbOKP!*{Fs=yZ$msNz0O8%=!O{Id!4Ifa?*;%w${31rE z@`GB)AFTLzfiSqi+fw?sdhze-m@rj^~|3D_KLz&o&0Vv2} z@Y0RM{yx9(cJCDtvuvtZfTaM~$;9-Goi{8Qfg!-H3XDua=zd!-rTp4khsx5j=#*xa zTfYB?V)4g4Q_zQR9D>+}$h*PKz=|Q%Ls0#2rr;>oL+k}PuU_Z-#B3_NYEaQ47d5%NMv0gL7eKnTK(LEHn9OBzBs86CIxa0E^_OnMhU!vx?T z=3kK{p{uvI*ZgAiGXr{d*-gI0ZVL{VGYJ`MAb%Qw(usAAT*GLry7Q1<=NU})hJ}Qr zU}0MSxP>JOT3)9Xza*#*cv1YJ{8;QGnEz*hk%3|#^Q@wQX#!HiRQMv8F6I0xoke?X z>nL!3f!~3NI!?_4__2YQV~~Vkj1`H`i){_aw|>F{w$JtQE5@3}$R!o#=l92_II~{} zq&Yl)O6mcMG`Zm^pb;;N3Z&}4)NABm+ZD{}3+k;)dstk8?w4<)qbE!H2s;pc z(V>DKjNIRBmaJCM(%OKViR^Qn)*iM{^yhPT#UdFh@FUqGO!sf8i@4 zdN1ZMr{ZK|GChTnQb9wC0HfVY742XZt!Iz&2C9QoeY=GDTltZ0yzho;gDYL<{(Y9n zPEuytxxxN{+CYd`cFxU~iu;8Fn$FJ5KH{^-_v)~Zo{Aa^){gwVfz{We|Y%P*J6JbWRVt=bx0 zoegt#KDV{jKbAJS9HfrGTjftx8=Ki4^9{)*J<8uHt$J5P?PSWE@@b{-MSkA8p$@d0 zh5jsV{<>aAzy0Dj50|}*v8;jZkMlmteh%p4j0linJAQ4gEF*^bK`YJhFu%gX=I3kN z^5-s_m+dH`xt;pR)GxXi|n*g$Ga_ik@7|9$MVsvzMLv(F)h-bXG2F z<+AV|e+r>AnIQU#f79iqB49T*IB)fx*lD;7C_DJ}mgcF4`}mAac19;;3+&Q4WO~ma zXQUu6+GacN4Tl}sIYPRJg}(n%G2nX09ZI|AI`AzmQ~S%ut664W?>wg2utiGhLgn%) z?#C4SM?6G14~~zF7#?4@as6|#yBp`d*xB~Hy#0dz*zIrr*H@~KiTvDi`F8w$wuh`+ zFAn*}-CXbU=VS6%?9p6Vxm~%(M!$_4OHMxT9lQBPoNI4*$ul76fIS+6zDeg8qWL}P;_>DV)kq(7@8o~yp zj2vQx-W5ri+T$PhOPvnkFD-56RQNg)R$C>_Mzs(jSfA|Y_-ApC`*=i>R*)W=HGh8% zk5jhXJCxqMpQr!iSyTI&X2{D5HUqFKa6|O5> zByN26N`$r|*Lfcd92<0uDFwNgn^ixdYAcTb; zVuq2Y26Fpzgwn`wRA026BV0;YI=}VqK|C+At&rnGyx2>rwSbQzc9vKHCxeFsTWxnx z$zXezAvF0A(Mn0zf5oTort_no6Dz1z^YM9JPGdGOLt;l(us&^{^soxrBv@QA9xC*IgmIupcVvyfv;5t zR42}Zo)P4S5PmJDn|wSxi()29=rDoR9ZOD1v4I#yEc)?>k(o17y7}jyH6K_%T7jwn zL`6T4hs26+yC5PJb* z`Jw=uZo5Q`)KEe~DAE9?lw1@re+C#;5}@}W@o44Y%fo5Uv=s1syqKuLUe#OLU zz88KMm)bvBtXxK|br7~j0I4x9nDsnJ6dNd5ZQ{-uMhT*kLVWP~oWU-IEQPf=#&bvw zL&Zaaeo=N@h0Fv64m;fMULrbc(TjL$f>t^P-vI~>XfPRv!wqPr2f>$BQK8de*4#kP ztp^X0Uogpg5AY9gWB%oHASjYdDl9@!>_B@j{Y{+Rh*>G#l8cQeaB$eyp}r=?CbGSG zv0UJgf+h-|7sGu0cd9YcPw|i=9jwBVL*X__<~0Czc^BZE5ijqSAa8>PViPUx7r?2H zN!1}64@eJ~H`*bVRYF;)TGo=7A?*@Ls3jx`jQW?*j}k9_x#f5?1K`co5oLy!)-Cz% zXzP4d=MMdVXZ%)XvxVG)un@#20juL(TZ($wOwY{k4qU`!-NHg%&ZsauTesEiqFj3U zN||L{(c@cWK2741IXJ>V4jo${R-3S^X$H$LMRO^u8?G~u1AzB1jJOL7!q(laS3z5V z5l7*-WuL5Qe?ZqGa|B*TA(EzPhiVQ^MfV5S`9PWD)4y zmr`ALaL&nT=;VXb zR99ZsuFfAugcwwrt|j!FA%*y34{{{Clh`Kz*h53`(i4`IEj7F6|Lg~Gz+If`aS3?T z5TtBEE;eYCh({z8TF<#tyED|W1&H?-IxyhzfFUsW z=EqA1=Y~X>foz3g2iysP(3b_gL9ziDMdta(bjVhcKqeyK{ByNyrg(6^=d z{-L-0|6BldQdB>D7~bWq86}wa_L9Ow)_C_biD9XknvTbw*P+VTH9UHdf<5*F5o8LELki5b8mRgHB3uLdnWq5Obz)ZW|MjN=%|pZCjI zZ>&sLe8l7vMgIqM_OR{3B5eQjd1OusCPnC1_u-zIsr~o-n^BidA`cE?bSoq|0MABj z<8HZsvFvUsQDA|aAd(WqpUTnLqoq%eh1VT}Ni4WJsq);A1C#Xw`W@&~SS#6GIzo!4 zQU1c)7jV$VwhMmP=xo1oA5RI#i{o0 zkL2PcQzRVF*h+ZZ-anoO0rk-g^C`6yC!rcRhldxIG15DlGzi>H-nBP*`UGlYGFOJF z6}C$WXm49Ks3$IDh_~XwVh5583cTDM=V}&NL4S9yy8ieMPYeZ^w2)Fmw3B+avuFn} z(;rGIQG!_A1c-q>7}=gkmt^GN82I)Ox0=Uve+*tbE?vyRlhe{P1p*)sYQ$Xt7pqk8 zp;d<7^p?3;cr8w@uLEoYx7?#*RNh!aXg^{D0@jgR1@_oza?^p>VXu*(HS}F&7w?ON z;;=3O%mbUxLm37SVHUlfs(r6NFZmIYuL#;QIz3mDnZvK(fhP#D8EG(oe9C#7was!f zI!!~sWo3(oZgIctEgl-|L}#S;V4rc$3=$m&G^UTA?a}k?>=E0?i3*x2(wIMr2mDe- zCQW~TUKw3MIeo*q-!=^mKY8IKu4mFw(39CSm6C7p@g*Xjo63TZPqq&l0N~L-#`MLVcB3mhQZZ#cqk4Gfdj;?RbaW zk~${%n}yf=x(-J@^Q>tr$#}%AxMkdfD#LK4x2h*&)P38FUk<18yEUXK#?SVDO)LKJ zKJu}Rpz@Qn6U#~u>3;ESx|cSzQ*!rind4RL2O2sH`hI$eNd7MI$fB+>NKP?s_to4) zQ5?VHS=gfqS_2D;x^rjEGqYl6S9q0dk_~wTyf1H5%N4S+{qR;e*PNkX`{bm3was#Z zAPk20UGIODQy|BGm^OU-f>GwTU%sEXgLBV6t(&S6k!E3Swi0M9ypnWxR*I!A(ECS` zl-E6)n|u**LgJ=PPD}C0#`o$xryjU;J_5lvqC@BXpRJ$GH4OB&a4P>8YdguK{CE7t z#62CoY8mIVQgsu1lnfsII&n|gq;!67g~$`Zspb}jxs{Z3>y)%KIZ9p~Kl%4>WSr_e z)T&jyp539|c)kCUKZSHV)q%@iyCi>~ACI7~eG+l++9-8UyGWt;-K0MHb-^#+h%3Go ztJ!y)Hgrq8mDh(O)bj>+gx4P$DI2pZv|g_#%&^TVGj5_&M=&GWZ`PfQE=cWa;N35$ zXlj$>8MVUc0uN-w>mZK)162^c_x?+|{LQSvoXPM#Lhk2>Kd~gcLkz=tu=N#KUp;27yfhv7yNL6`*faUHXW9IT}V_N;p znbFpgLFy;JI|-Pat2kLgo4{;h+s+yswugahci}4a&f;7Uz*xhyP45N2WextgO>=hn z_RmQN1SE6{KKl8`xAVTpy7_#zaQ+?fR9CNVD0iKUc#!KAw(J95efgVpcVE(_IQ&Sy}wX#xKJhR+qVJY`{BK`5mDl04KG8;u+*#O5yYq< z!De37RUfJYblL)obPUz;gK*I65Vpoq2Mp)vRXWw7ZCB2)t)|^@p)XuHbY>)5@aE*x z+UqNqac_6&is~-@x{B+BctWeHs`d}-GHc>KkvlhkI_Y{Wm=>foq^I90CKgs*eGvUn zzph2MlWbH%f-c-g!A_0-=zQi3!N4H76x2!hNWYHG?MIIG<#cWplJqkD9CMGd{#;De z@r4m7!2bK=MS5VO+EGsyNw7W0+5F&e$L~4qJdzLKm;70hjsjhWpw5SWqb>nM#K53t zcElLBXUm(M_6;YlFUF)l%q@RtkkL_n9P=6yJSH}z9VG|?CPH=MH@Id$8qTiR3Dx4! z{!eu|jF&OUw$NkCFJAzG06157uuhGQAsRL~2W}wAyf18}oCaaG%zFCWgENcbRbs(Bs!mAo{hvPS=+S=YmiDv2@lk(W0Q&;Edf1IRaZnd zK4UdWy8?X7dM#g9C*amC=Od@HpB|F{kptmP!h(VU2t$bf{{4H{N|8$9P9Y(b4{yP< z?klnjiY?JtG%ypQ=H(sq$jUFZDd?$NrxPVB;{b3vI5jg|(U=)97vF&ciAdK$Ize(I zvvbpq{33ape- z*mZQ!)n_-wJxlla@i88c-Nth`kS^pv7ZZBViGz!sCgfs7M}c|NDVoQdS^%^_clFi4w2#D01EmI=tW$&*@6%6b z6)sISp=*5uL>ZL=2NNQ66M`}Ujl+&!P)I1{b?Vyh8~^rJe7elB;ijs09Q$-kd_3W_ zqM+{Z&4SVorJ`JeMU)_EQ80?gV(6Sa$i~W=-m>vpmRQEWM{t)rYJ5z!cG0W^Gt zuB6oG^_w@H2fufn6M+9awYuptW7vy%`sm7M>{B4@pclinjdICLv%n;#jvGY_%po!V zT01fM?tosX^YpLU$&ZTr$9AyU!^^{1CDoVGXxqBEWL{B`)TCFQQoN#GO@{3P$%54W zp}fhqD)$6~`1aTb@aexR%D8CpNjbD%kD501_XF?2v0;YziOHIrU!Tia%2G?TZ8JqSo11a)2CUEo!rgqmwprw{q-(nPgeI@2u3 zDQd6QV}``H!W1OE8�j|4lhkfI|S@CK`GapLGiNv2LObm0CcPGe8qr_YwULv&?=% z5XaHDZsW!UyjU`jrQ9@l3sX*1W9V)(Fc@nt*0ZsBRCxm|4W^X@{7!ekt?+8q!P3J- zlHkG+u$`1;e=&2WWdXm?A5;>ql0MEtiM=k}ddGfal`+g*b~_Gn*${?1#F7^uRpH8y zH|WYGqweFFMFkov`VD$YY3|bV>$5fNkTTx%to&r-EedloUxW>v2sRh$ZrV28=&+}3 zy(FxDBORwLI%NtR7^^>g_M#afe;puo@Rhru8ba9rdj*EQTfki50#hD~`Q)8^1_hLS zV3%WHWaLRW+k^@Ot`6EN1{XD$<3S`!QR>3 zGX%NFYQ9nlf+umDk-va_1L?1o_DKLAj5%G%zkBX!fCof|uO#Lg3?NG9FI5Qk!zjbV z_ZYGC!y%jsG#gzKE7*@9i`m*G1N`c?zJ#C54}uI34rAWj}8Kn ze@?-%nU>abA}!7nDT3=mbiB`TY}sI4?ivCX7Kl^Wjq580^zj%qG-uGPg@lKv;gH97 z+5YqU6F)z6-FaAFrP;mO?#}Do+DffzefkmHSnl%XICQ~2+lzEA&evf zDjsz|$LUve`YGgt0jG_yxv^l?%nvWv!{Lwy5#8jA=+jJEr&96)2t&nAEC7(@sfNiJRwt-!lq$Zz ztx?%a`K)?EVc;*`O28OI>cPwab2Ww|amnUvpjR)nbusx;* zn+FsC(hBsKyB&smYbFKuoeD^IpsOLT;w8j9TQX;;rmwTWsQ@Sk&>=GW#KDZm(L#K$6sM%ErHu4R!#QY=LcjywDe^&j1}_7cV5*mqltPVY%w)ywm&-99m5PuuE@wq zq21zL9LP;3#=CNaUyjYVsozxVk${-pymqQ{7mtHO zsJ232KxRl)zx9qdllqCMrGZc8TH(cv92r7;P2<^Gm!CXqOfj++-F$<(jnPHn54{u< zOIqW^5uu_}?`JMDYwEwgH0{j4-|7(f?6JF@8LPanY&J;`fF(iq0!Q5Dr1qb41wRiR ze$&=(f(4u(N*ty)9ah`30LQ5CQE2HTPPm;^<#mN z*~6_xSm%gP{OeYS>}(c3mUW`;^6tJNceob!J)j+Xx15>Z-CGk?c{(s~e)*s7LD{~S zv%~VKshvp^9-`jW=TDhLnWp<3f5;QDuJz-v(d-IU^crheQ~8YpeMRcm4Om;#l=y19 zBp*JZpfyj_XpP;V;IQNA#V3Iqe^f+1%Flo=K<|0U%T=6G{{(f`U*Oz2AL#wnjFNdY zbs?Z--Qe|}y5UA=hu{{G^Fqw}88`dcc@EwEGwa3Tb?ltupUw3LJ!|6+aTiIArtn&q zeK0waYJSYKy2tU$r%GKtW9F#n#W$Y2ZolH*?D%n?7|ek@|701x9OTeA_CR#xN#G#8 zbi0?=FWHkNIu8`5PLI5YF`UrZny%N^nf%Z6HLw}&XWEgXUJ!>rH$QLxz`&{Nb#qgQ zDufwceNHS#>+7E#u0EPCEJ$+=X8a#+zR#iBZo)$|Ty*Frn^#VJsVQ}gnW&VD%y9~3 zW@2e@_Nh?tkZnj5WBiM-hg-PJTNfnz`>J|^>XW@_E=@T$OqtaO>L(t*&27>ek#dRf z;WBP^nsT+YD9q2VuCA5@-Br3=82Q_?qT^Y2Zy6id z*lMw_a{8Aw_oGu|->rBQxJu9D?cApg2P~iGYu$}s=#N9uF>IP|*i0u%E9;WJ6a9l* z-kbOmZiCT3Svk=o?iJ60{qZ>l>9@BCQ2$9R+`BqAH+x#GV`65owpOI##50dw zfEvk0?~3O^eg3H$SZ5!FMabM@96%Ba+$&$TmNC4~`q1|<0e zARDT+Mi6(2s1YPqdfU#mcBk)qjgKGSP1l_uocHPMp;!kDA6$r#95?0N#^FtE^<7aU z%Kq`W!RKELF%JeoBCKuh-#IN1r-?;B&O(qLiZBV~m5RsERRz_+4q2ydIF%@ALUx&? zZ`v=iRpZPX{|?3h^fDL*@4>(tY>yU@r30-#W8y-o5p@!yX>eTtgn>(s8$WpJxIV!C zJve+(S+DanLv3DnuXtekHk-6RiBm+T&iTp`OL5%9v-qgs(O}sA6&W(ANV5ii9QEH* z*ck-lj&}3mTtS13?OLR1z|9@6@cF1 zCfpS3Fyj4KBy^U71mHsLM+i3nwcvd--j+g5o=ATFk0|f#wHx@tsH;{nH%9i4WNo7R zW^_?Nkuk3@9Is7M^kxOfkCJKvTxY~E3QBO?FdUib`YZeQs*%w)TpUf#Wwh7-ZsfTp z<(sl*@rDoADEbott=o`Zvj1O_B@yX?Kx~hLe|+&zt8V7kJu$1041#V|%4SGN&Bz}@gGp{*h*$usI>4U;*EMb;!rsSq z?lIoxhb~nbl&kZ-;y__Zq71-WOzOJxkBixh*dQGXR2!m;J$GY;19JIVSBJpV%=B51 zeF=4*Vdxc^^FW7!(G&yk8kBQ8pZJhqFv{jUn;LdR(1OT)74-`&GPd>f2vT%(bo7CW z127tLfCX4K;9d?f$QZ#QqUyANfd@xK0R>BAWIlq13~l}K?xh9Yl<*`pe&8N8fLs8& zakg#l4HT&U$Kw$vFfgrU%=|z6pAZG1u>d37H1|E=N>N~>{;S(qS%p`b&eklN6G1?| zEh4moFqPOI_M!ob5L`a4MH#?=<}6J;z~=fA=SxI5=$A%%_X^}&XqBu)~f zo`e>F_!bEMDtqb<#^Fh$Euu!u%Ik!B4V;rP+|F3|<=>DtXD;Rl?AFjD~{<6lQMxgN+6|MZLks zN>Jh9*v5r_I)eJw9t`wrm~nmil$Un`_EHe$TC-X3?7yIqD~7SNFlu;U;eUQjANKGl z<#lwbD{meBg6#vG8(MC*&m>9Y=_u=5wVc*d7>l5ECQ1SnzG#%89NP_Yks^e!VD^LP zsSP*IQU95de7tx!lj6hIY}uA{95W*@xFvq#-NYaZ&w=|)2C)`{79B67pSO(EP zU>2g;Mhi&z!T4cSNO|!K>;VnJV+h^^@Adb7(#SE-Yx?X2o&uS2LmtfAO45f>-$!G2 zBhT*@)`C8aTW+ij3@)uMyeIx45^;wq7ICisCoV^WjF;W&RV_z1eKFy;9e&{w&f_Hr zfyvG%U?fTYJ^t(-8nL2VXMTS^RU`$sB}`cQYHui$<%Z_;zXyslgoH0p-CoZk1)3Z3 z??8A03ynM`%$**Qu7e^Lw~+9d5G6;!cDVflm>^U`DB|kT`%9?oz!0O(((NEl8%pQ( z-Rj^|6Y&O86JL1emzT5Ntmn=4x5z)%0{-i|xmjI49Y%az^v<{*sIp0FG9jO0X2R@p z%EspNxhS2|n!Ddlmm*_EY58mP;c8Qf?LBW)<#wbLtHD1kzR}oN;memT_w^V@M``Qp z8C!HF8@QPAgkMZLuT?&v?a|#(Rr!ybUQlSqjsxAHZ#SSQ^Zxa0&a2DbEka*G^4yal zkL<2>zr!m&R8@#-dVI0C#L(Q_W>=P-K27&(^sU!nU6vm;ue7GRv-R4B1(baHzA?C- zPBPS|C_iJ@WNQ?^`s$8%R)IydO^p}3^bVo&vu0cCyH4-+dorRsrcP6)Se*MBq|zH( zD2yl{xlnF3-I^^{Zg^yCFN@~-CJ7${n}XK~+|nO)?VmUs)m%vmdab)7Q9nlffM;`1 z$Mn~tEP2K+mzoCcS$fRd-Qwj2_wD!YKIFMnR^lc!apYHR>v}h9j;GV|V!})h6h?QB z`ED%ueXS+NY^QJf`~_v!oI5-At#Z_1CGHplc9 zHEL=EXjtCA6G|KZ-IYG>s2zuLhH_M5V`wQqGT?+3h1FMl;*x-`6zUU}`i z@Hl_k_8M-{9Q)#udY@~eCplD4Uc?8gUj8v{``JM zBUXP*=pA!*&dCMyy+&(xTsb(7VJ z-giNx-acv*9aWeWpRDQe2oHr>(L%+6#@8_hknAwArso5*ilRvQ=@s&ZPWT{qeCny? zgM|}m`S0{}LpAVa!xL-bmSlApnYXu!TMiBVw31BWD|Aocs6FA>oX6VGDAZpnztDJ> zjxs_wiq_}a@0w^yi^`A?^L>{&>A2>k?YyLnjUys#n;dI~y5^HMf!PAw86FmlT|5a{ z=P)NgUy%V*ER^}66$$q4?$;`O{+t{0cfj!^MjbqHa&3c5z{kVLz`$2cNvJUJCj*5A zX*L8CBmrC)ffDN$xIMf{QvR3lC)F)h*3BuSW0&l*kWMNJ8OlE;377v7N`2Hn04)G) zWgw^YotpAc03U`>d!RtPPC-4VyubblQQuLXFrXWfTSG}imBsU zcECDTfLc&wGn-MoP25W#G5g^Zt`-zkT2iV(u*HJHwsO%q|Mrvr&jsi{ib!Qr8GsW& z^Q+=APgrPZe)i7saW;L065hZFpV%ubvz#2Kt{WbuEN+1rnZqCp|bYvAZ1 zGT%Y^`tf|_ERb*ar^oP(lL+Yqh3nz$R`=w=+B zvWQGCa0(={7*}dgeQw0Oo_S~lDX94Ta4b5v{aG`6FT|+sRi6kW2a9KTj|a#mC@rb! zDy+*6!Gad`pxUEv1ntF)3B!vc=H@S}uPy*|2I1O-&3tcn9Gp2otRl&^7d!O#0zcmG zIm=7LvLG8%+8OK#-YqE!^$Y%xh`9KZ%1XMD%XIPsx>r%*#r^+THc_ZDF@`n zl4NPMSO3ts@L}o!q!!_>S7zaCa$M%+l^-8>c{~R(W-004G7knAJCv?fQYwm8b-NIy@ycR;`8QrE|A_t@SVc%e$nuvW&7AqKB_o3fC~O^nIPn zgYgQyXT?uEe}FGo@3EIXobv0|aAWlRt_qjG2cwk6ByrO$_x?+h%pgpBbZ^{B=2;|- z_sc>;eQxY3qzFmc$F>54H~o&m5=5;Unnqw4VhcCA%I|?-ru}(X+ADAo!7?YQzObtz z*&TRg#77S!1MnRgI5}^piV)9XoP0oiQ7?i#Mzh-sP6ZHh1Q=fefv>3=N*E;+M@m3m zPz>!skQYo1iQgoF4%|-`i0Cj9UpsUh2Y>b2f%C(_%1XkgA&i3Pj5KwWB2E_$CqH+A zpCG>X|D1+k!a%&_09ccSg=YQ9N&HS2#cU(OSv&wnPpU5fffV!t_@=RVoFi=xQF z?Lm_d8edRRSYqhO36~sMext)4nC5|(;ETc96Zwy^H&#b@5OTwP{iiX2BH;I|D;_@Cnw z4Doz*6*Y}Pw)uvHqV0Hccw)w2!=golYB&W?4XZ9hw-$Y>460bn%C#ZTq0VM1+n?g6 zPVfN|JC92PX=fK9yTDK};KeCE9BGjFW4v1lR+eV!M^J37twWI*4aft5eeRzw^-yM? zxby2Ql3*|&L#Kh^24k|dhQ_+k)LGa?%p>0~n3mQm7EYLs;1QDhiwF!M{s9;0DhRn4 zrlp}{1IHB;kyns-`ELGDW4+jhp2QQ}FTys0b1^WzcA!#d5YTvG_(j?gHbhB|0bzOR zBnuuI0U{=Piv-9N!25_B7r;109NVNR25wR0F>wUD)e^8C{*!A{lQ28$j#1u=>aJgE zcwCrv(+>TGSg=9+D95?}t8Jqf$f?B_R0?Ec;r7B{RE;xd zzqbNEKh;BnLsgfj(Seie3@sX#2Ii~wP|RaLNW&?~_>%10@e=7Rol zdoxZaGE){1Q1|8QU;1D_GDKBbAVQ1FIitGz*!K*O?=pXO{^;w{4YV+2PM>^S{cb8{mL{i7&@)MvWG zHT$=knlAk z?nZxCg|oGvw4fJ%rVzOMs3Mfkji;XS-dHzz@lJ5to?jB1-3BJ-yiWe+mUQW^WRriR zYn|}ppeL(m?Hk8^9=(Ts{E~v*L;9@)Lv`a__VKKp1=qAa*`LCB-lBfYq_+0xUI8H? zA3DyTn+#dom6%*VfB7~pEh(R^fvFWdq_CFk&Yo@xtnbyBN zZ!72iRN(N+=`Ohx(nt3t1)Do`fs}TBwUWb|`bqla4$sFix3?%hJg}p0_VJ@ts!I}q zw{%NgGQ7OLj<#I;^)#_bV==);d8tfv=bT%(=tbtrWyd{N1|9e$Y*L?SoLsvr@$R9? z{eykCx>Emm9uuSKcJfW-k54;1?J?n+@@i2y)vh!2NvEk|ON*Cvsnmzmz(}Dag?HYP zA|6wYM^wiY_Qj2D?97&1A7YZxsryV=|7Wki#~7uV{&pt5shv3IyJ$IJ-F>}(XWi$Q z4E8q_y|4eA6k_7YNys}lAd#{l@+56XNN}6+(VW+%@oy)szq@NA___Er3ybsm@cQJ& zVbl6gCvGenpD5VIwVp3!L(zVw>1c^nMdcIUUTr-)IW5!DGD*+wU6AFHr?K$O>)2VK z=hwb|6V=fAvr9Az=9(SVPeV3T-E0mHyYy!m&$0H=PefKIEi8HN^Z9tVBHvL>9q!bI z$GMuCGLs|4)ve-_K1MrVN^Tcn*p9?lul*Oz<+SLxb}=vu4qrY|^>KhFJUd&IV8N01 z{~+Rt7?r~4Ur?N*B$-8%!QoZ1@ILseyaC`1{_`oE2!s~;(W6735SQT=*kaFkftcj} zkL>}^5~0;Sr2%~>;(5D1CqX8P|Appp-OJoX%qzLQ#`#XxWobg&EHr4vQmQPt3xH%`U+~vIMx`u`reqbsOim8T5QfZ?mY?(pw02_wV$4TJ> z{>mtncsrqu1mwub!z1}b7DX~KWP&7;wFfdeA}a>o>Gk`EKBg6@DM22vF-es{RSv0S ziRauG-y1z_PjDdP;^SM-Qh7^9?V0GogPdV6)IjtCIxPSw7zApfV^dRW5PJCmtO1;u z>ZYHe%Uxl-`67(`1hDTgZ%&MIa6axha5-0sZMB3(BQHE5f8_)h{GK^D#9ofQ;<=#* z2f4P;g3L@3Y5~#Wp*Le>tPTT4VD(hlM&osT2PrlxdvLZkE7-j%+ z?*(+_uDaJIdecOAUZOUytYaT8GfWlMAJd+PtK<<=Q^=v2u46zj4B@FQutJOVjLv2u z{Rd|525jhdajS&z<|;6^tB|4)Iv%N;$$x#NCKf#pf+YZ4D-ePS)J(qCnx`uUC@*cT z#F$4j?His24L$w#H6<(v%;{KMTmg6!r>&NqNlj4WLo_WdQM&HLUAYp{SBtRqFJbqmSOeoR(Y8Goo56i^sHOM}n>+SHU|92u}F#AE(K7emKq))9w z^M#58Oa@-joe~xFRmYqXz7y9cJYmq%LHlU}aYWP16{R8GZZb?sWnH>&L;sxsvJ;`z zqtgsHnU^Yps%rymvN|U0VJ@JCAc16GKx52fP5P=?6g2=kaGi z(GS!R%hY)aZFULvG7O}Nu_AGYg{l^TG2z660aX%g_TK<3U4|bt)sLWZ`i^T3 zc|ijU!z_leWQcQ&)6&gyP}adow15~v%#FNy@wvF@bQ7_(Aq_ezRl?9GO$cUrUjTVC ze&(crK#BE80#ITYg6{!Q~S&vUNwgO*23qbtI~!(V`Aig6m+L?Tkk6zSe%-2WG*2;xHz zAGa*ZKew<}MqdU$CO&v6;S3{phNyAr?|(@`6HYRz?CwU!#`U1bFQZ^au1Y#)1_yt( zsgo{{FdFf6Hmhl2Mn-NOaKW%k=D-A{xxt>NzkNkaenT%dYBLeaAuWOE7@*y8?Jo@h zihycY;N_zYKVToPp%K7)1D)gZKa*mRY$4*}nmSIXb-sBFyVEnpWr%7k8uP z^v?au{+#!^_|K0$lt*I8mkZF@bNj5eG}cr5C|vrm$0;KI?D23gB3@2KFHMTQx~hCr zxks~Wvl<4jyb70PjVyL9OkX+se5qmA%g9vU)+`jcw-;_4+EizpE@pX{Z&xO zhF{Q<@iRyGXiTXTyIJ&N0{f!E4Y?&s=lL5mN)UPcQSmP|RfObVy#IxE80&@~)g7Vl zC4m-;T0IFgXE8D66KEaNpZLc=`?x+uYFumc@7pCQ9b%jZzS&JHpx|H9*1cDm5(k$9 zz0xyw_O~U19a=e=S^xZ<{XMm4bTp^XAt!9y$&7KLYUWsj)G8cI={BxTpXJ!K%L`w^ z*z|kGa~sDt-aXU(JKXm%`~2+-H;^3po%6nAZrj3&poiA_h&fL)-aSg6>%QuM1Yt%u z*&1U|zV_v08}DYvXzS_O)As*VxDFr(=0v{xZ2K#N{&KkwMUVYhcVrg;u5a?Cbex=~ zP4emY?VypC&NlHp64fMNqUJg4x4F(Q6-se`ak2Yvon<}`S9&N-cow|CD`w_gJS=h~2B|ucIqPHs8c^MY01d?{Yi3sxsQ#Rr9#lH26b4DNft; zZOhTJ4@@k1wZUn5z$vv&;0Z%YiK+hXUxn6>X7dw_g0VMy~k)j#KvUVy>vBq znEy+0{;y$RdaDoSf9+3nm>`l7qdODMRMJ1w-QCRy6A>$`c-!krKsG@q#r$&_1FlN2 z{wQLk*Ets>N1Yum-I(JC2;-JG+*f$=W&~|zuSw(7qD^=0kn^7>Zj5rjBGUu&TXyXV zRl@qK$|;>A+=iz&c+?M`@T|gsC4s z3#8y4EL7Uiz@rBbcvpT|yNymbd~ZRNAov72prfm#9mb&!k|zdW46yIj)(+nH^_Yv3 z>qC^IQ4tX;808^a;i)UC&OlVh?C--Bg~5b@h>eakFlCc3k z0zk7u=sgENes>w!52DJ@U|9gvU(x`(Axul@nD|@3>HPO~o9`y8SW;6z`__LA=8rXB zjdFx^R%;g-pNjm=z3Us>cHrv#6E^Dpi%#MpKOqn%YL21H8@`5h9{3_*=rj z=ykCkzCF<3MkiS^22IzCS37(5hgr*pRU0QeoYr?RPe0k4^ND)*RVyHaT6%~mrcvcUL`cWEj|fwVgd_RR9EhU$Q@HU+ z#s*kM{6o*s*4z$u@2mx9MNLV;i=w-XI3zu&zUV2r)ezOS*wc%1*XV5LJ*H9H< z$oi;6zfuq*lh-j~s6;c;({H_bqesZ9$Pob{dKj=BqgEwKL1Owt<_oZuBGJkvxY{Y^ z{d=b1gdGbjz%XRB;A_H{DctkD%k`|hJO|wA#1IqrCP}@cR3OeMuoX8m)hAah353CZ zVA!z(pA#d7Vpv}Tm8$Xm>qMiI_*M(F9beazD2E6&fm9^Px<|t~_eX(iO;SsU-!)1; zVl0Cug%}RPBfPud#CvkVV!DLcGUD!_a=nVbgRAv1B)DjYjuy@x3TJ--RI3HP?9$Q< zC4INwIx;PceG%o$3kTTefT6l}R1+|!B}j8yP#zQ6av&H8`S}K>H=I-=&;~)x zDhy(bEIb9Ao{;x}oGxf~LT1;Vde%fc#6uU;no#|(9pacV0zjeIcaP&hcp(1*!Ki;2PSwl1RSR}i? z=G*?PG9e(bgiLX{N^*1R2{)Jmqep)MI1(=uAzHQeIKAUM-JRCj4o3VPUp5&C$_K`Y zh6HXr0(AvwpU}52W9op8LICw05MwmxccBu&D2uo$qvQu|qqo>0EJCDuU+LN{EqlVc zXiwF|lS7=q8@xtzgNG}DJ+80I)G0qBmGdAB>-FKY2{)BK zDdEsvt81IOG5%0h{|z5+15U>CRI=N(yT9DJrRO|;IkmXfNbN?l!B)#{^{+$eu3O1J z%kEi=H&b4=8l1@-T6XPKq~BPgK>vLBJXg+|`b^NSEQYS&uHDn47eLxY{99u}4v(;~ zZr|LpkFH)@>+}1(F$Zn?wfC+qxlm6yZfI!8)cmx(pv-v7+G9p)eScr;S+AqhDhuWR zma6Q&NV`0b`=Bp0GqVBdXfUUVr#P!>K&5Co>0Vr{c{O*KI^RBvP=6!uiP^~NRsZ<@ z({$%iiI+*@?~9I7;Y^9!z#W;=nh|E|b0fk(=l~o0{rhgo;q|&a?Q(mr{k=DyKdI<@ zeTi+|@Ocqbq_#3yGG~${uKz;~1Vq+Ka+l>F4N7gsmyyxW(++M{N;CANm)DIyc}UTd zn=!uFGTJmFNA6UqL#~EtEUR%I3%8(&#c*1x`gSC|xdklVzHmXrW3HHD*dnF1>f_$y zk$#V?QQZt5Gte_vQrFY7+pyulmv%?x`%IVj5)YFsGeac1ZQ4X%Rh9WXwkGArD}9;e zi0>Mer_`g_z`hm~O38oUm>|b{(l7)VG~+gh?Q-n-3f#wE4c>H(xG7F&X7aOcgvl!x zc;P9A4a<~PN*qCrPpz*Q2QM=n*Ex82r-aMje=X$?LSF5ItkL8A2url>!U7^a?RJIF z%6>am(8Vskb0xmMJTBhfklFreMQ@z`^PG6eiI#KQ6of5`zMPg#Pm)_%I`NM4C=-zO zvlSG@r}Fi2bXaSfcZwdXW?pR#(FXSxI%yKQ7Mi-=Jq2!Ba0$dn zuDUrNI~(Y_#AFpmLi&pr>%Pcz{IA*!c#POF6N}d`0NnXi5$h>1+%V|E298+?O^nXY zu*=^HcM!D_5JGT)*1NQP4KPo^K-zy-BQPTdOV4nDEAkIAvaFT!o4f#gcT@hLqki}bIMK%(VD zSycUR@mflEe69D_dm7)G)|Tvg474YsMBD8k+vR#OmZ2rvT8!X)Jny z7snPts#`6sbvQc8pO3%$dUE7i)12PvmGid+){dmKk5H?I*RFgfE^Kla+DJTH7L~m4 z!9I|w+r8%7)LT2(%71-MPBc2Y=ym#8&%4U3JdPs0K=_#;G*2=rWob`JUa;g|7~gJ0 zedg$x7kMfHkr;58z?(TGxQ^z0bUBKoTeoicjkhxThlPbnHgc9H44^J>Lj^^1VO|f` zFIUPYXdyRhH$84IbHAhDEHhE25##OMCyIqp*a0md%wQbRTvFB)@Dkw2=L;V{i*gq~ z75{w>_!2QodH$U3ukYxdN;wty8Bk<*MAVuE#xsmAqExM}R);+S^|!R5rf)kEn0sLKK$3t#SSnLvGkCU{;AC+R)c54v->9{e}4=E4L(QOVf1)9Zwgwo{K^uwl@VHcc(d)FjFJ`0N&#rVQtD^Ydi@1Ob-k#9cz2X>za`2_vXB30(kpVX4N+ggUM0Jh5jp zar|#83i|^nSl<9UJ=gY}i^pwjL?&5Vb7Z=1=+wzn;TT8tDl{ zfN-I07tyAHpmcA)(=!Sc_&$BY901aYM_8!1S=c22^<&~W9PQKg^(!nba*yqTVFuye z132SfjE#)s_bXmr_ChIv$ODpOM*uYhXMntAG(!74rqk2)?!aOYnm;leMOA~?E9m3;ye<}W1~A%yz5E(g4;+im!ekMbw9G1p zVt-+11eikzb0=!?TL=K*X@;K~M1xr35`?>?MT;$;SAOEDC# zff)gwJj1ld_(xQ#iv{BJyb*M;s05n_Hm1xV$zm9?CNWPr%ykh(^fZ>Aak&agB@Zpu z2y@EO3Afd46n!(9AAW5&c%;Q(Hd3D2qT`sC9DP`2v}^0&_XY~FS0@qHGnpKHCE8w& zUI7r&R*(Bq_xVNZ8i(Z!Sy?6Y_8mNb_*JpFfSegXgz1#NFql1gAfr8tac%AG!jb4NtK(e*f6mtxe-;bzOcN9LT)g|> zBPycWGe?w^&2wB#EV#v1gd17dGdyxB{&tQCC~T4+LngqXO)uNf;|RpWuYo7;uv zPo%WhO!O;n2(ychaenfQ``?m??dj{LW)0icz6PoJ+IYA5*=R-8H_a7zgj2_#*mNO9 z{*>yveK&7PtaaxzUXB>9x%PT?VWs&=M$$dbqJDRGnd{fZb+ah9Iy!o!K2ABYm-ydnVf4bq8O4XK5Q$d{NABAICe$E(`zAJ^1L*8-HU74RF5*Ly(ur-KkcvloNN~3VV19e?aFJax{Clyr z7GCPQi={7cBqgu3(DU)ZC^!qlUD~;qYi$%Mo@h5`T!95sBHZ>$d@w2SO6sT~Fn2lEYB zD@j8103xI`=BHo@Ek0OjE$o9wiLNX*m8#*%rIlCRV=5nMpFe;81Rw<2(=hXZH|^A+ zUv-7xm!R!i_VM+*o4e&(%179!O}Nn%7@RJ=ufH7(IIv-u#rT++m_#75#jN1MSWMe9 znAEnGxjP|d(;e1C0opLR9>7FeR8DSY@BKp-W(KD-cp~bAGObX7TIC$|vU;GAu&@B* zj^l+6pUtFaAIP4&c5vy1ISDWYo9bS8c$Y_B+3i}HwJZ7gMNdkAX4KUoIQ+PY$Ya*q z>eL(aFpg>oom~oi|d!)OS1Xa7{V-$)XF~B0|sa3GZ~9a3%aF3>;cA&32M85N3cR zq)f8x-_3n&7qrJL`cC6SqBJ{Y5%;KfoTDMK@y3gzp_s%IVi78=oZ&h9<75)IT~cy9 zA$tF}QgOIOfJ6`NvN}eQ7&Q?@0gpn7>-eYM)!DsZiDE|eTvyP+iUsD@=llYmn$xcn z7ki=YBO`yC%Yv$+rq~<4_=FBAq56pnk!3%9o~ z{hC_vakZ>S$eZg81|cm(E?p;r7YY*bH}MBeOc$EqxP-%7GcIPk4$b+Q+YNStsDm)4 zCX~J&?>S2%)*J90xSJ(2MJ=<nx=m-){k(5`<_S7v&H~njdkfa_ z1C%cG{e1vSHn`{2qLLF`tw1e=I;<9dhtJ6m%}{JC8yx9fN)v&SIx8ebx`THafAR=* zh^))Y)*;jmF$oFcs`+TMOuE+%`|Ln1r)5Hj)@YuHyBR>VnB3glte&kG=1B4l!CIi% zn(B7#4bdxybP!cS2gXsZ-yP}jS+KfoU#{`Z_9!Z6j(k}2?NzjFNx$$#L|8yTmCZzd zPc7}%t2OMo(j7loMe}$>XZJd?dKP?A=zJ@pw3W3kES%?Rl^o5@p>995X+Z&F%9-pa zS?{s&h{pXN(kW_7E4x)@-OHEt-hN2*eD9Rx^y!FlvIH6Cqi?|j`V-N(*E#Q0b|?n$ z2CuO%a|uRmh-BZ20Y8veWAJ%V@>ECrY1VoIRY5_Ds_J{NcFne2y6IFbz6oNsX zZIJX`fWM<2bdNul7p9=ws<#eGx7gyMZ~Al2Fn_~!{;)N<`}%MviM(@XBhk+R6GO=P zI_Ro!ujuK1@{##MS#quOR9W!@T#2mu)lEe}(SFX*iEq*u+H%^Io z21qrrX}iGrRPld~IOmeaj2!r@VU@lYP6(F)BLIQ#e)q>SY}xe#J#Jib*pF7AiSv;C z8eYFHN32x9dtA~x8FAC;Hz<#_=Zo@D{Nk`!G(8wv5T9#+_k?2IecTJ!UiDaADihv{ zb#(wexxQv*TArSzWig|$fF~?+yglSEeu8cuZY(UQaUs~mZg#W#a#~haSX-MpW~2LI z9^pG*!#Cg0x26jMFvNlY)&U-ZsVzex7`^4)zN&Zcr%;jltWG6Qg`HXk#!dx56^g`H z;xs|c6ON_b%S)AOibSYiYe_jYzRP%{%})kj4|ap>C8nOiijY#NbL zRNNX=<2XYZ$)KWh$uiP}QumR@iN8DEKW4@JS4GhO$>6ETn3na%%Xzi-(%}|S?l{> z4QHOqL)g>QHU8?V!u5N+IY)h8PEfIZSxDeFFrdsYFgqy`-hFDGB3xB^|;?7B4$_5IV?>ZLCel8A7 zya~9Vub+56<;8B@@V~moDq?sy2CS7 zFy_P|CwD`6kt^F%;db=={uIwXE-lJfRu?V_ zEi0?3saZ<4vZ_@2rp?rGaU6+%ZM#F~HVBR;@$zhcKlDXcKY)74`$LjFEt^Gxzt`(t zqw|?JyIw>^+1UNFE}wkyTI+G%#IX2FDoR|N_^}uCR6hT6E>xJhaxdRe4@&R<6(7KQkQRs1kjY*=49=h7Nn+B@!PURij`e^ zWk$o~O+Wp$+_yTH@5EK7Z13x5LSD_0%|9+`?Yk27k@m%l#qXV?8+?*$^G^HxHStJ( zZXWkv)=h?~`Qp&n#`*z`8KVxx%AB>?*WabQ?(v$A3C=Gww@ zTMi#FtG%sye?KA=Q3ocVH!Z944B*L@!En2h ziQn+^;b6{gKWFPebJ1Xi5e%UqKClCb3F*Sq-WF$R+=LhCziFAc-=d3!8RXWDLDWc= zAci$yyde0i4@%{qM=KZEADudkRv6$BQ4sX5jJW~{=7&n@9m&x$yDIuQ}Ad&SJBoRo0O>)KR?<9`A_<~aj0}YE4Cy2)z<}o-R(Y8O0Id*b3@U+@VeB%Tr zOuSMBMkk}cUv|BBT7Td1QI}F-ea67b5lN|xbKHrEiTNcu`+h4ZDgvtjhhGX}1fS{0 zd=0%JckbN1YL((~xyT|X^ewxDl@;z}jnsX}H{ohfJnk!KeGOU?5WZnHMg%5gN(fL5 zyFH8mWI?p@OS0T`v4Vwi?Ij|3FtP#5P=ICq$+qK?55u;E!E{@j96#mQ|8pMRMy^L8uMlj-j6rmh&dhUrifni{*vnQMYFGdXu@x{>aI;5qU0`_HliqBU2Lxe5#Qm$g~u;mFhTOv#h}1uY-IFbxB%n}ewXRMhC^+C zCTt_ZaRB#_N&&1zOH0e?>FIx+XOq#<5J3&;Re)f{a`mLl$`n>>(4+-Gm4qi<-_><5 zE(zKt7YKkBo3MjBZ6Th z5_?U5*1<^ehqHoxt()&QDu{pa-1b+VzIt3_cz=gAmu7#u! zCgW+)e?U0^-W36|h(9?3-Tpurhi@33dGNBZQ0nscec%pY9+@9~7FhX#rN0Hys$P&l z%Ut~mw_T>a-&I{)ia*@uXaX>e#RyZ~R{`FvXK$^#UBELj4;yy3z7SDkvpA5?FfY7~ zlH}=A5^#jEG^J9#xuvB6_B$kp#}4e8a6fxi6wF~XB1B3<7|d9N?#QYmIwH8wqS(?6 zvQHm-S6lo3kps-EVV8z99Npkuc{;AwuLn-F20wZ_Sm`(6HJp(~G7!O`?#BS~$%_}^ z(94|@Md4*H61e_muS?gZ1C#Ya&eyLy&i}JP6B1DKGjeGZH6a+eFhM7V%=~4OdJhc_ z@J1Bw#8c9eCozk_LW0shQoPQ9@E2f7e=m4?DiYshFzjr=0$KjEw^9IeR5VNZkO=~Y zMQcw0aA30BTelLqIm}5YBqj%{v^Y~ikRU;pAn)0WyhyL!bZ{4vVkxGTmN+Ws*I`Fm zMtS<@)avh3zm6A_mnY()4T*95YEx2Df>K|9xS8~##aDVI#wjqJ-nS)8>rPSnE3ypn z7+@_Bml6=2YZgt7js1{bpPwI%1`GD?#MBW$c)rCsq-?cl1B5SQzL@Lh%C1I+PH1a{ znRSmI)7lr=aMaZo*}L3sQMI>T+nBvibElRc7NvXiGN z*YhHrjJJx69!X)_Dk&|+ijEq13E|BD5Cb2hwE%%TCN>sODjgce!*idIuZ!{>hXx#4 zh++%g;0~8B-!(pgymUmX!liK?4Nd;5N<7NC-s_aaDiL$_dT+?oI3#D}G(txiT|u4G zzVlM%`skW>1#0QxEFSd#Vv?5wuy#)uS=h55#fQa}nC$tgsFsJ6N*!MJfD75LX`x$# z2#h#8)Bf79spad}ci^kK|IMqbZ)jNXVCR3N_{#b4`M|L;emOZxK|#6U^ZtcJ=Q>6e zz8E;g{QG5qm@(W@A z#%))r<@Yp2Ah?01tJtHXv`%N-wq({NypA`-ZqknWY!mIVg=~g+#o(-g;Y=I0b=wY# z7c&^|R<7Kb_mqx}t*_7Y7L%N{tT@NbO==Gf!#j?#F9kCGQdO<_Y5erD!gdQA{k)&* zMS)iDntOPk@cA~_YqeVU(J;LHeqQJ|)sE@faCBg@+YVAq8M$sKE~Ghkr9C}o?sMCD zfzA}^3wPNo`Q95Y{d2Fncq0{A6^x?JLXyK zek`Wo)x0pJcK$rITe-q`jRv!Ju6se)E!ys?hY9Dc-?QzUa9r&8`i0h>@0}v+X4CD4 zxsh=hM+7(9{>s&b)BCf%3y_;Hz#bwge&uM^)jqRSU z(VofCI-!n|kKqMJY8=qhQSbe3EwJO@)7K)p-Wn{2>bX7-_*bh7_1E2Z+{dPBv%GKf z@88NBx2=Cq5nV+Q&cw|W&Mhsy0YCJtXt(OS<9xh18|!3OXm@60i1YApr!_kFUnp`Y zn=xk&<=(;}$k`kmvG?B_s;&68qN_(|QI~%syRaY?4BM{#xgEYq=r$wIdArpE3_%f# z!C_|iJYK|=ousz3kNDa$_ooMS@?3Ku4e{jUk)EWuk_&KyM zl>zJpUDxX>?r9VpuZnukx;e{_R!ON!`Kp=VzpIn6@onaNI**HI?ARa?Ut6W2{5|_$ zs92WS)uE7hXu0TR_|V$sJsZQ8u=kl}-KM6CA-BI=SC%&v*>0g` z#(q4d|0-v2aI^o@ZY8QkjkR5#k?=zZOt4TPl+;7zoMPjS>vrs(+H>ai{XvJok(@mf zS*JWV{p;Qk`QXBj)ul^+yGK$=6?ILbZfw1r%gUGY<62e55i3EV8UFmslGnQwPM5@~ z8Z-N7b(!&c7`H1WGY@Z_R$V@-yKlk5R+byleZRq|9_#-#`u8vKC?T~1$;pN85i@6c z6ey?^h~AKB2BM;>kPR?6Huem&CLAjhFRpkWik4>B2Jha69G{sisc4;MNj*515|0ux z27{gML)iR8W4m_GT^6@RY6?79h`%A!dc?dYJ6nPXXK^~=r-+1zB$m)nQDV!|Ca1$X z(5q2L5~HtbiGW9Z9uWI1?K{1?7KyYb;;@AB8NLFj6KK~rHFF5&{(Yyl6AGFu-XkYK z;_q`-Ajfr_jR*g*=pS4IEF1;A-$2#4NEN)|1XT#FZYrIwBklzfbc0$ck;@_uos z`M!LKy|U|^58&#Aa39@n4RmgR^x(xMn#A`16XIlEAlQ0%T?v)r;lD&SpRfU z7QSd9A0=h-P!|sl4yv}w5^)7Ez1@<`=&jBI(otPwQA~J-Vul#43=I!QkKWsJc4NDh z=k7-_R0wL0x_^HMPY1>gJh+F1rViK8o!b{1fxeb8FVEHEJd#RaV-A-avd9uN~3C*+AT)W&#yCWjzQ1@J)ud4U2z zj7RN?Vlp#Hl_n}Ocw-8LU*btSRgk=L8@_sd6K*8`1t}#&-hT9`4&?0t-hN^7hBRER zNPxu6GXOdU4axGLj!G&5n&FCsbD8)LfP08Qk{g(oer|280S(l|EDpS2l7|8E1hLYF zo9H-omy-zWGOEUX=1Lo@4+O%JJ3)BYsyOR?34iuR1K96L_#g{ zAscjY#-j`XH}3T0uLcs0m%Q@#EU{pM5CyEv-ld*7;=&dl9!`RMfOCS4NvN7=-|;#B zLA`srpcB#)d^VDf4Rb(!6s-{c)Q-A5KF5tx7sx)4O8bA32$cg@6aOu;H@ulNbg+Yf z-VF$BE|wQ!;6<)f18!eHfq4*kIU2S;&{8)z*^v$iEM96c;h>!XoF8DxGGHLna>CtD2SF=+aA{}obKzy^9R&8Bc45rIMx2L$&AZ8v;ak5#txf#B;{)j#Ja z`U%Mu+?GiM&lJFL2?M|yXpD?-^N3T(w<3nL)`R%hphf(_mBs*tSorx&J?Q;6SGk5Q z^1pf#mkhWAv2_Qg4IqQGt+@BY1xG^^|I50>@G`z~Sp+g1eojVEN0G4^^EDMT9z0q2 zSfBvoj6m4pop-iTE@ZrNPZg+&ET^oVxeF#c^`yyG&y)2HUR#|?;(Eym51#$a{DHz@e0SscI(I$y*E~U{=fOZeLpnibEuyO84`~Y=<4aO|qvqGRH z6JtHd@W7fT^JLNs6N?Ufdc5E+o_XKJ)?_>W$LQ4Xd(lOa1{BZI-GdTX46zLyDdd>O zcmw?#cG3cD2*=+gM@GKE*$;-yb|5T-ntt^9t<6suL-)ZX4cg*zMwXEnnA zOY#`r0T=*i6CfsCnX4u*HbJ%Oi+T3_^h|MTB|J{gkeq}A2bZ`8TSO=nih%kGi;AS- zra{gIOlt7stkF4mr$q7(A#a;9hLWHslQlu)btEYUHO2v)s%QHKjRHh ztfBGaFlarNZwJP^K0nKPT?q*cXu#3~)G+{wqa9r34!~iIbg=MYAiQ>7w;z-gcq%?Y z0ft@RS~7BLIKf0t%g6`;!#onSra&`0Ia%4}FJCl)4`yaYKYX|=OUqlC4*=NY!|2%9 zx1c?RgodiWnV@$l23m0ko;AtntO3Pc z6;h+sZEgERHt(RmK>?B$A-FmA*ADeN?+DyLE32R&i{g&)a@KD_*BY&-8-C0&OdbvY z^M}Gbta{-a|0s`|j2ya_%q+R=zP}kSh0+!OY&R%G4hJjau$1mxPwUPslg}yI%FFd1 ztNeLCoC!pReGv1~{e7R`c1Q!8Hal>+`(RvjYD+hhhzReeWM=D#HeWmay!$3X1~Nz7 ze(4IWTle#xLnxJ*8D-DH1BRTVL5$a2=xWN-RWH5Jc8%bLtAZ@|@QTCW#;P0(Ibmy` z%L(U%MC<%+jyZ>gw)k^JMPZ?A*={8NIKAZY_!~;Tk@nJivY)HdO}5{zWoT9~4}7&R zvDmARX7f4?#`Ms!dPU1WXS)Y}a?KtUyg?x>%z4kHJTTTyiQ@Fjnnc%Qp9_s;4)Yv7 zUuVdSh)aj_E&H77b!?;cv~Rt)bj6t_*YD{6mG(Xt1_W&1y5wg0{vAH$d!5*R(b_+_ zx9_vQ{_U~9M*601%Ud_uIHjaFe@o4~oCqPax2EP`|32G*kk5||^ph?M?C+%5AgstJ zEV84p&>=R~ZOEB+$^4_kqjW0YlWU^>YEe_7J>6fQN~7SW-~2&ApIKo#rY0q)w9i)P z{&AD^P(un!^DRGGjXdWk5Ap5aul!73UfLrz3P@2x!p-zh-dtTn^9hfizjmnD+gB)` zP8(crX2Z3%5`M+?W$nsbXZ^?0N1lT;vo57Q75Ddi)K5yK%apZVi&y^k?fb>fKO!o< zA0DZRA9;4K;k`_A+q0UFBHJ^bq-D&naMwjF?f0&Yc(-jh_)Axs7)RH$bZcYD)!pAt z^s=BaXP$7xoR591qD%X<|GK$#yh*VP#ix7Ytw0*iKm5GtW3Olj$NE7t3#*)pv4@}d zXJ!`nWNruv$sKdrJ2VspR?VwX`K^z`K6X52YiTikZobd0x#{z1E5_lkRXiD)6^6x^ zrml9qFh6|QbgZlfsfLkByy3y?S1i#S?%u8DGT!#H$&`K9U4C?Z5W+K;X?0$bwOTJB zf@{4!^UGJH8MD}b_b6&ebo~|_V%|_M_2?Y8KP{KrWQf>p+U}%fCrjpmPcuxD_w|jB zay-ctS9Q$W6mu{oRh_v_)otzLnZU_OilB&jE)~O|Er;}OKWXI>o>WgxO=bJ^aKdS< zjIZGMYnFTR*9LgZpUVf%yLAh=f0R3O&B=h%)O)3D#bs}3=Fd5eGtczQpKVGu>i_S~^VNz?|NJG8HU|KSe_(atYm;+RjEZd#%2erBooPG7BpGX4;=aN3hW?TH_1 z$-JidVs73X*T0MhEqhVZhioc*IrLR)$BtVJYAyV98|QOxO|`JDw|f2o;ImQW&`9t; z+2D&!#^CG`?9*B*7(oG;b3+fTH~$!-W#R)p`NN-?ctN}KH`doj?#sep_%zOzmUEwb zdU`~gK?8PFNR0lMhusr36?O$5N%QD|CQ#mB|v0dxXC zhx4rlig}_nN1ri1wS<{1;Z(p~a<)ofWjx?LYtT5Vq&ql!hz}1s2fn%j7845}K=zmf zke;DVWmOs~Ta*GIK%%6V>3UXK9L6vS$X0UC4IDWT{XfHg%CuF!{w$^`D64Hi;AJ(} z`uB!ZxzOOWV`_v77aH21pu8Y-5xF2$h#AoUYaE$dwots-+-d5uec83p7bZz2C}bp) zbl^VlXz(A=lmUCZ4KgH#|8*!LFor?W3-LAv`+_J4iOi4W1>%^kK3^o`_RpG-YoQcd zL)I{;3Q;T;Kae*eo2slkz1qPDqB)gLOfx~$wNva6j1sP22;377v@iZ@E)N!M&m z1qB7t!J*y6JOSzO{O!Xygos}!W+yNYMq)HEGQt)AiP*;VIc&762hDV;FyB;bV%nAjxW^qDE{78;1eDhx=;1>W%Tb(GEOn1Z!-E;ZAG`p0R7n2DM zK^}@rV?!vX7{`Dn<^+RPpA=CxEj2SURAaBw01Dj<+8}n|01BZZLA< zI^ZTn=6vM45cw#gg+-$S>;bFj;`dh*p|viL7BG+G%C)%k=0}0mf#Wa zK7w$?MlhzEAcw1=+FyM)Z1~%MTUSvLyFn|^8Gzo%vGV_TMhKDtB4|bTOXZP!4tSG| z=3L+Hmw-ml_4Wo%$bJ3tG+Lb%lc zKtQwq2+Q;B?z(SzOAtw@ygHB38-{_5#8MJi1EI;l|7!l7q>n(#S(+z@pb!Wc35kdx zlxlJxiJvguGn{Ni4p?I91@HoFM-=T1aLK|RX@mELbdfEv#6mMdhI4P=p8%tb=-$Tu ziGQDU`10V?MZeh@E+X`bQvmig&A$-hCPX6qEQ7Ae$6##_jScYwd(5{IjA_GbS`>&F z9G!;zH98IB_XM(?!2m4=Zx(8sdNeec@Zy*3fvGoOBr=*jcm8~IZi1l0w-a-uSC^>J zT|TI%M>0&I7d17Zh*bi2q8&INNhASTYyy=nSVX(A~8_lQZD;cJ`{z82?EykVE|Fc z0qA9*t9Ml%T57Lx&tBq189y)$xEnJu%#Q3KYre*OywqiP@4_cnT~zA0Rm6E4DYWR+ z(9aFle0?Og0a{BE+CYp>x3VZ^5G2Fay;wE@IF=uC-0+DXsRA~>Wp z_w*QoUJn#!McqJKQ}Zek-7!-}L}Pa?wmlpkcgvzu%KUKF9Oq0nx$^pZXGTH65%b#h zo4y|t3`qga0{$a#2idi-H_3O^9qRO7$^um0c=!m+C9^3bEiDM0Qy?)XvVFG!qzXN_ z?U17#`%KM{Ga7ZlCL3rY?BG|KZzM=BW{NT5FCWSkpYI(VVsqyJCJ+lE>|x~kU>*-! zwr$8E0yqGta%`fxSS`nI2(juyxlaTD_uz|&O{CUZDiUnR@gr~>C-Xei3gDKis9e z*pbmd3i*mkN=6W%^WsU2jZsJd!@ZBEqv3oDIhQaD#^ynE&@Wz$=6~huo70qekSaz^ zrz*=NQ*ir3*=O;@LiRZ2Z(YVX z%6N|VCi%FxH024xEPnSm1AFFZwQh__gg?$JDUfDq?kMgKS9v6jaSIpMaG+}Z_C4iU znN%0M#?G`$NQTNt4NW{1=KK;Iek3H^?CB9XX%l#qdR;bS2fU}LDNaMheWpPDgpOa6 zvXhNW79-yo>n~zIfB4vi!G6^czJmQ_eBIrV z_ntHzxk6X)%+le)O=Y9Z4VISHvO5Z2(xJHEpschy7yaNg|IQ=jS3?5sY@9}iol>rJ z!0kzQ-iWfPxy@mn=e+*av+d#Kp5=eP|E}!}ds$<5Gqk|RFmZD8*3;V5pJM4U zGB>(b?+jh--hJZ0v@_l7g+J?Q?YCY%dDS6H94+17t^s=p3^F|&|1L4i!rQLuQAF~3 z56ib-aW$XN}4x9S%D0atg zo3@c1UOOY7ahumZdPL1kK;)fYjZLIlqvWiZsdWVD*Hb{BrML2t%2+r@nK~s`64`6$OOL=P&p(a#7LKYu{qtd-%xX zaYrAOhX=Uc*tKz5I7C-H{bY##x^W=6RGPEqlDky8Q+Qz^4VQGauK6?GW3zN@ceYLB zxrc@AIa{4FPNk-+Cl#4=SzT^0Zr8aCTE<(-6Vq8+qRLxO1jKy3DBs9w*H zm^#??R(P*W`mP5TEcV7t=<{RAnCdu5ASQ_>D9T|QJNMczD-$>)19mbyltv_u&Bxr# z>>WVidf+-nVBRM@``p%63%gT9g`7`zSsx@c_`45HN;HceoaF9qTqYYO3``I49%G`# z%+j(9E@YS}jg`+4@B80R1zFbTf-ptY=y<88pSFvN{)eR{CGnD27Box*F0CxRK4K15 zuo7q>!Xc}uX6XJNIdkUO#j~!iFP*=Mh>9MXkHXP?XrCZb3@|Ey1@-_$Tz?<;+VT+V zIv|!@>6@;hPeY=11FC#`zjQDKGLEF7rUvhU0qE?Y>u!ckCt3Pf7csipgYgPXa)py& z{A!B|a=FfTIVeC>59&M=Pb8QJ2_?vt*wuxRNFB<**Knx-;2AwzH8O)N=0;d+FBZab zByf|CsE7zrl>M0TICtriJ+9;m?JQ^z!e5SpYz_a>>X8{7N7aaaIf3Z}Eb{uf0WrUyuXj#I>Se4oB#mU#vKrE9H?51q2Pu22a$oc*4)@RS29a49{|M3jK z_(q&N;a#I=z#~GALqn6s9(h8<%aK8S6m+W@9K*#MoN|YU{{leR(SswDRao{ z##~oEs$HS`)XwA)CI?%J_j@ZB4CqI%qTrBP-Dzfq?yYR5?F6P;K&%fTp9xs!9Y{_V%^4QDij*51c4jKM(MSGz}Dk;L_v3=l9}9QA|72>MwxNU-(Tt1HIToG3#= zCWzJu97tfb2)rN|IsoB~3?~D>2l@Bh6T_hThL>d7^F^ zVKt)#qDTvbf{`RF!B>U0EueK5;Cc)%mA#B0=8Ty|vMlpG&3~r^(92^wbqAFhNlMi~ zTm%-p2+wZyx36vZHw{Qpa)^IRx_fu~+|I18->%%18}!WT=+`n92Bn@XQ`n}%7VbDa z+s+icKD8t~+hi?OkI|@sq2V3Ii+)!3NoxHR8bIpBeIyw&aqX{*XQe!-Mi>%6heRyG ztLXSZhI0q!3qZj0U+L6#a^sjYg!GQbmu*C?3323Ou?>W?j_Q{+d(#gLwMc~*?T>ib zW>Ct)BrAfUm%^UK$fF>8o23MFF(N_vV3nEq9@=zIG$aeFbCpIO9v+;%=HA{)1TsLF zdd#c>o^0B)n&#f-=69c~Dnrl9jDG#*Z>`r!cIrI)_Mr}_$JJd^q&zrdCy zADsLP5W!5%L_ot;S^mk43K$VUpK|x4MndK?fsHAT@b-j+o+edde{)FKPF-pHLC%2m zSB=%AJPGoa6?`? z8#63nvIn;tCU6W6OnpFFWg)U{8mKpWoLvHlMl9&jW8ZBM$JQ)kEqi_VOKxe@k6AZh z3KMPzaGiwo(?IaRF?$v8k(%pNubYopD`b8(^4tc&q@=|717@ zYnjYm9D`EEDNCtM12L4@&^{p8nrQuMVbEE9@g5!^GK!z5-3ox9UU|lKfJtfVt>Pog zmx?Ti-3hQWuaOhIwip)LjEtdkTpm7JX^Sb_q!La|G=WgnhNg-EfihsdXS z+d!hLofO%fV57X}`miaPC}7M~xiWfXsqpBV_o05?E-vy|kt1^tV-X63u23W}sJ1v0|{G8jYeY1R6gL$Ow`e}v+D>JrSos4@Q6lm^ty=dqS4K@xAHKjkV z@TaumpG0RZrIrc5yt$T~Sp4(`dYy;uo)eZnl}V1tv3Ay1E28Sx$JJ`9418L*@6=}g z>MeJ8t`zLh)I52I#ki?s^lzE7`9Z04I=fUsQ;+Z5h67(W2?kZ{eSi6lZG{vg_3glV zZI{=8>~409W_~Jo#JJ2D%F8j~HhngGJB6*RxTaP{bpOx8E!{kCf44k2(l&zGL!6rb zZGhU+oNOnRz_IfF(nfpk0)w9Dudl(+hB^mKOol~#HB6k-o)qli)Mgo-8fr>=!YNk2 z_cQ_-zOGBp-q`#tbEkFu-2M1-&n8BpTeK-V#Wxio{=%$@cDhzI~r(ZWTveTpN>gRNTAsLf;%(Sa?{+MR8{A~(9h{VFw|hN3iEka zn`w83v*EpLOZVnq-`!U%c_VHuRQjo_J`W#sVAmnzm8d~Zlrof`5~q#;cF$4_ej2~J zj!J#p6ee?+(p)W-104%j#Ie-hMj=aVgX4ratyvDu=Bcc7j=rbdg+)O*AwB&zGL|RT zMn`|Tym~q48K&;_Q18s$8|h+k*%IY?*?9*=OE}CF9*wAam4)828cKV~m2sttU-l)9 zoqb&A6W`vE_CbfrdaYK^i$jK6vOI=-E=y~89j4%^Ul((ctL*@9lw~{ZuQKaBi z!0BU=G^gHg+Mz4z`4aPJ0NkMEZ|7t@Y-#Cj<~F`^5GRSs;^zoAAbh~C`!dp!$c8{2 z9mudLR|u^?aEAbph_QbXe^5|d2oRxVVpl{0Qs+UGZA46$DmH|s15=CyKLTfj$S}d| z2Q0YXcl7{-ch`L641{*=JFh zo`X6Ppb6MmkmNlAjTz(nBgc>Lm{{{1xp3jaPQ+`lZP@~iF`;r08F73(!ufXAKbpmhC`6*__srv(4?T4CA2R9pz*>R%+2CJI{$E{-~(^d$B!boQDCUk-v{I6 ziPOR9Qdp${H75jj%$va#A%QX2SVL60fq=NX;JgzOddzHeYhM30Tv_7o)h0YUST^0G z`7-4+FC#5#*|^{1pfRA-Bs3_LWPsa|b&a|}1KI(6Zk)v!OfS+bWt(bx0#BCg6#b0i zW(#DEjDi`N;_h1l?yRGtJosvOY5~qQk(HYkdr;wk{sPg3`BBC%+G<-80H*9L$8=>M z2wngn5DX&wV76Ef5ddFT_Sw_yQ^xoOn!~LsB3?KMjEjOCm+O#z^H11KCfXRkz zW+bkW)J{Z+!pt9-KLmPjeqIgO{|j6z+;PIs>Fx#-uk>OBOzvc4Sc7VAd_4~&n#}kr|OEsW|nZAP9joVAd^D# zYb!AB4`7t^2Bm4MX2GJHfT=jj;G7a07uT@By0M^nV)GQjgfwyW04w3@5q=Y#AHU=n z6A*iR4mMy!K}r&*VPS{b^ctWtl-f_HEbUxtG0S+c{cm{l+0O$#i^4zt0te%iwheGD zjEj#y0>eQtqc;Yfep8?c+J0y%_Q~z;FYsK!r~*VG_sg3`d{7dfM8ab}Piik>Qhoz! znl*Exo)JxLBI$s{ObsMGLaJJBtXytnFBiqaEV4j!L8QFM7hbdy_J$?f_YUdZoC?ms z^>!2YK={ctFgJTvI$XRM%QcQ3c$iH*J<@?*FCpO@T}!d+ z8*b;c{C!iD(ZQ<;7qX+`E&y4^W8sJ<_26`m=koCIFd3ns2*<_%FClox4B(AM`%6v# z6{>p}>*op7n`I%&$U=M!UP2X))Zj}@8k{dsW=LeUcq~4&D1c{KfSk(EsX7K{< z0GwH#Jbvtt>qB16OZil?(~I;n{4L=T$_NZ#ry=k^uIlj>?o3a7NTAwe+r~I?`_3V# zM0U~luG5ckIDrlpm^UNufqSTs-R1F~1tleoCIqEz67tA@X263aGrbx%my2F6oGW zujIZA4qXAOry=UgQj*7h%!80rWAm~L11o&whN-kvF@VhVD~HOC=#IZ-bgZB*x6m-9 z4Cj=lg1V`x9eY1Rmx`oF-!cdEc$-btXHNbdM~O_PO3+5bQi15~aIx_uybo<|?){6k zXeWbk%6g0~F&x2k=zWpEC+$Zg=cPD*omE-lg8t%BVHQ*}Zvif2%>un8CWF5!wlLCS z@y1QLWK1BuRQ)G+xOQhx~_aQF8CwwsED0lbhnRtkem% z%cbx1)NTlTuS$ROO-{Mu`}k1I?=j}VjXV4WzVF|?tsu18Oig)aOXb?6;dRx%&LBN0 zw+R)2%x}J>nx&GWeG82taNJp0{Vg23e*T)^pI>@QG^cVB6WUxK(na@72`90SvIUBV zPM+H9BFO2uUx{1ytk#n7w${MnjQAV-qx(~b*wv_Kevf>J89Xtb$EVxm`O>KD=yVzk zXSz4DD!Xt99q98s+M>+FvpRfDP3rTfX7>_qMd5FX%}yzcyo=dA-fJZ-y`FDNC5q%d zR9vO+HTy0}oISl%@S?nzL*Y#PR{x$0ynzfXSG<$i{~e|dNe!TDemf_*cRXc@ZIJS@ zXmap&#S_eaWvlHHwzd){#a_Sic{k@tr@C?5zXDsWy&re~WVuNl9@Mz5-B8+bg=vdY z?h8vVqjTpXnayWTx_Eo>Yzm@^vT5yA6Xayni?*xmmE4oKvif z84qm6pw}KOJX6rQ_>p$f?Ld4$<0dm{YD?aCOKe(08W$IkEit!5bR)zL9}~@=5Vrqh zKbo+ z{`_LMdvR0N;yGOv14x9`8>eo@PQdz90J<~$=73}j@8hp$hO92F_}rfnUERXfNj((D zc}49!r-i_uxh3EykWVf=e|GNG$!tnVX?2g082-ev*x-o~gEKp(1)RSJBwa|DjT=4e zd-|$~H*3F`gd@z23|j5BIHs?E;J8C^=O5j{s^6M#n(12(ny6jd|3F~6YG8`#)hEfh zPtC74FMd8jv;WzA_OVjxawD%wm9Cx2)Ya2Q(`j?1Z}&}CY-3pYEc>F@k#$P=p~NX2 z&E{!fSbu<4Ff%+`KYt6b!O}|h+0TZRy}q5wfyzR&PAQvHioN&~S)5QEJSmsu4qqA0 zn_s!X{}e=y?`Ib;B;OOY*`t-5tx30TD1Uv~`V(6gpY!Yc{tV=<6rp zlD`^HtNJ7Lu@feiD8G9re&N`O zckVia(YZU~Eyh#&ITZfHduo?ybVy0-s1MM#&`bNMM^wD;Jp3WX?aU)3#X|3D$IAts z5?MNZ%0fv)-D!u4m?Xwt9!>S^vrIY57WVF1r=s!T{lltEi(VC!bsAp09`bw+IYZ;# z@d_%IhHp=X&20L%*4RxSxW05#NzSyDe`-SiHSgiwwg>1eruru_4!NxVEnId>KW#jeMZ{OyszL>sx%ej`kO#UYsWqfvSO?#oYr($h%(zk!E7p;fY zxy{@Z$7jz#%@+UuHajJY;@%#I<+TpQsd@2X&+$s%Jx;T`_jhWn&$=j{<)^2n76wv; zgjTXe(U#HcuU{UU$tW)RcK*VP;+ffR?@r{Hd*_;;aLmxsditeY{FR@e_@Rz#hS4z( zA6~!PJ5~3KG2|eNK>Ez5rJ0h^OZ(3LdD--^$+bUuO?KXQP4aY_N|#3Xsdk&CjUJdI z)Cp``^!hj>I#ae#IMk>(dTQMc5uAmFnRmx}_ER@tKDf~Kvf^tE#U#Ms87``d9Nlvq zv63tM=7nn`DzayC_T?<8HpNxALYeMclfTa$2%cwMfRuC2hX1Ak$_Db`+hU}9zQ&roPK5{v= z_NT9Ja$#Vq;;72_=u%n#o!(WAV9Tf-hsTZ!NA!86Wx2G7KIDs_quo2bx-{UJq%tm9 zn>;r=x>0}zi09IZj&8&0ZrP8#J)KJA|9CLyu*e!T!# z%ww&JJ>5-9bXq2hFQki9>Q0({wF+!n`??^z^zU5vg)3uITHAt*!pcXR()JrAx0Qb{ zntS&r?CPz!mLo4tRVE=|_j5=aP3f~bj;eXe={ppt)A6;?8}_J7{~e#PgVBAv{6D;0 zf;TDSCSPfn)I1R?e{j-LLLyoEX;~BeVK{G9eM(eJGN*OEQL;XKrlUhIev;Mw*MZT6L)y>q*RPOU%zzLEw0?_K=&R{q~VelzHdvmIl7vXM>L zIcKAwD$TlAuRmIzBww@}H`^KBoQbMjQl0-^Xt6}S!}G3+(*^705rL=91EWrUip+RD zk?&J){^u9|p3*a-{{LwD4sa~r{(Xd`g^*-tC%ZyrWoIj^>=6wutL(j!O-pt{k_iMv*9lkU|On^ZLI3-*NQ5$G47$=Xvh?y07c=IX~x!q|Q>cO_$cKQj zp`MOy&NELLV+I@ayr|cmmMc|e4|IL9gC#WLQ&YLt@_gsRx;T2(A@{76-?}NOUhbJfzqegnS1V-$-5jpH|6LR0mb0OAs%rbx zCk4?*X9Hdu`*mKHu8zI@p?08Tdb0M_Vy>UYhwEOGT}{{i)EZj`WQ=l%{ys81GT`hv znN)ARZP~$~#qH8@g%?K`hE`Yje!AZF$za}g@VDq{)*Y+K^3ydpwPU8LUsAf}J6ViZ zT`Mefv7A|-=z5ba+>Z^RxygqftgcFEWEt9myjhN?Q+ zrRMw@?KzVQ__w;DxWdAk{^a&CS2s6WTgWx;eHZfS`{0qTnx(9m0q(u8V#cMCIAetu zil?4Uq1gS?c(!+&ZkO5z4_Uo|>0Xcg<(uzh&pw$h6N@>s??|b6qb&5xTc6uLZ1*q} zV=cY^X=S|3=ol2_<(g+dndn^N30>;ld&(kJ=83Tm>QcpwlA*cfr~WB-tnz$3m0BSR zTlBVGplK~@9-AMXzl?Y4l0{ox6DgU8Q52VdGVf|0bm;`8rJsZV}9(P&-4u92Aptx$8Z z_cP|-tFAd&v{`%{dbIh)_~Jp)>4cyw`Qx=2n_kpU{#F?KBV^^n{XzVNBt8r?Jq@$^ zg&*ZYCySC3RkI{NuXz3UZ0w*@b8r|b7J5Q+N-OJouw+)0_3@?21@ZpBN*0#f?R>$9 zB(lS4_;Qt%Bq%$nW$dP_n?-{7gjIt2Gc!{}=qNQb?9c3z8W3n0YFZRyFwQ(mS7AxB zS6I@R)w#S;>w2Uzx5O6K%n;L(OjFT&yfj6&b`?5r-U=GqE!r14Ju!)*Gpz4bq*z%K zW^$zLzNai=QlETXq?J|MYhIi`CexedRNkk3MSAuuUjw6q$|z15X1?E0@ih8fz1H#T z2lzG0!^U$K3cFQQGPfALdp#|+zx?i@Z^y;>FWPQ$Woiy#-!8nD{?5go)zWu%c5aU) z%bA9!M|`X%wH~mz^w3A!s2E(o{+89$vVmf)Hl9{k@ZM&6x`ztfAPO(Em&{phMy3JUj6G>w%#Wxy(9<$ za+3j)1)Sol6PWn|aY!0O0CU|Y1{c6p+!pgvF=csC6Aw8MGf3_rqXTqx=-V5~O?Yp* zFx6_7*)Rs?7>~}?=P%JzlAbnXl#F|ZfxG2|i)T2+1Y8o+5RhPa& z#=@f)TeOzZMZ5qBa^ZkrboRUcejU*i_IJNA76KH3X>^i)K4(b#58(~(z{z}x z@Z63L6~2V=aw{N~n2m#_OQdB0o7@%N!cDcF%-VSv&WASHmo@{F0Ib4YeGQK^=14mY zDc_IH!~Oz=1(>z=Y#JG-h0|JjLy7Sf-U<>n2>i*=7ZVW^BcqNSt5hdMO1wK)7s`Mq z)N`sg$jJg{CE8j!aLHO6rNykOk}(rrai&f%PJHW(zJd5v0GYz>(|0Bn>1cE4!A-2J zYK{jzH5W5BHPyg%1rY$E5EM3h{qy`JiFyf)D|`$51iT^!1J?XDz2IS{|Lv2}S!o2L zUP(6Wh`0?;hc>H`bOh2ugEvj|wl=Fj(+~=~aO##F8zJAJlLi8;C&}i1$7{D)@hB8C z-%A%4Q?cWs_m13kEfYYdxmfsJSKM_F1R7=94<1%X|M5sE`XwLYyZJb_Bw@csnYFhB zt2VA_`QYK|fV&XPwY9bRDaO%<1bJ$)Kh=A0 zP1vFMFGp5QEs)|yM1v3Qw83a|jWbWx+Lp)4!ojeKQ8jty;z1wHeZQiy*R^9o>)O=3 zx|-9I@Ee);f}5ih8%%uf9g99|E>-W6GHNZ+`B*mQSH}~h7vI^_E>i}tzak}*$eH$N z*Up8q89zJ9bcUPXczr{+8pMY8h`aj89&t+_eA2r|K)BYOlygBs`gl46kwHuvK;mPir1ld=|^DL!C=w%1S@!-p|!$bVk~Ty-SxN zHS^z9ilB8hF2|^{g8IZo+@xJ??g%YvbXX^?3!I8~kxKblzOoeV^E~v;C7OP&)`iPk z`?e(fQRFl6^ZtHpKubi$Q-}VRuJP6}ku&eSHr!wCmV5YEo2sj(@!Id<2VD<;CU&or zm!4tuDLJ%}A$VL{u_4K?jVswoG&&$O466iqhk(=|r4ZofE)~|~5o+*-g@L^=Xruf$dG;IuRGc5a)5+jyCS-Or< z1Gp0({ghtAFhOqAqDuF4Rh0;W7~xHdi@+smz?e}Vwx4}vz!!x|$WJPg_S zD-gF0L{q>(l3S_#621HS{&QT`=Y|G3Zr&xaW^UEfk11m_8CB0iTPY3WS14}-X5sQj zHLi5lTATXdx9bi|(E3KX-Xs(g!V9R$b^!`4_x)au`9oVul1#!HlT|4C$M+6#X92I` z68IzQJgFOj*FHrKTl##F5?p}5;%pVKeiKF2yIrhRRi_2ONE;H2IBhHBJ7VE20q`7b6{fFvKzsw;m<}O)n zex`q)*DKkA62#B6f%py#LI!LdUhAvaevK!c=Kr_=SVhOkF9K>Gsm6z$7$X}J=mJr) z1hIARaH;sm)uN35F4~e)d$;CpY=}-4=_J=`ZafINy&AGO5PD zb-oO~Mw+^YxUdYS2ay(lNNzL^%R!^tSfc*131d)R*l%d=2Opc-^=<-W7_7Hc8}5^HrTW%v0I^~ zWS^%xny%wDt3id#H;@6W*FY}OxW(Y)?)>JoxFOJjf2lc z8+N-?9tt_SIJnc&Ss?51zea_>?w{Pk7s^7NjIbTrhLQ2-J#$9M84g~z*fTRtIQb9?b5uD zf%|e?MZZ9T#GitLhe9m{44=`PR#QtvTF!1VHOisivnFw5*v~H`8y!XO2B(v>k!|xM zI}7{meTD8xv0NQI-L9kc}@?(6@n2fquQjTT4tCBr3yw1cGF~RpzsMNP< zb?NUHTt7g&@!TVwpLONT+kDPZCn%sIV9Afzz;Qn-l)9m8P5jyCvxalA$QWa`{L$8% za3W)C?63JSC(YM1LZ_$&U9^rc{jJyZw=(-1-2P5pZ;7UFuln*qS(V4Czgg_~de&_( zOKJVe1|=}|9lE|V?-|o|vOA#>_iK|I*#zkmom-;`{zB1C}e2M%knAT0smxJ*~T=}?t-h!L6 zZ4<53CWn&0X}W^)f5&r(C)^EbC!MW9ZqO2u5f36XJNo;DKEruP)@88G6T>zXxL5qdZ@s+;Ylcg3 zZG$o3Rjvej7`Krk|L_+1P-UKa1OK1fAS1f!WI@Ka&>&%R{<-?k39~cA!u9L!SnuzjRzoPj^{SM2EdVInSv>_)hykw~f9XGCc1J>JX z*Zg$CaeyD1cS&<_;@lS)?u>#>=YS*XP`}ifUdg-q>90{lr%U}Cw-#omC&NWp##|Up z&`fL}s{~_AY45rVA~n`FKHZX=Ag)+?ZC2ZWB0xtxji}o4a&p%3Qit2k;*eXsr7)R9 zLNdgk&8q6le`O|wL9pXUeqB{DXpSU77!^!f=_wquzOZ%}g>Xl@)ch(`Gmk$!s@5z8 z2`!c~H>#cLhxEw`9#A-$F$+y zO-qznEoWMl-%^O2VWr+OYq*-u5Y%Gz<*9aiKR5Tvifix_jjX9jm4=9A!`)SN!eT?fygC4qPr(LlPh6k^&kebWgLk9XQ(f+5ZwYK?Qi2WJnUvY&- z#;)V1R`{DDn(lbjDHo+-3+sRDPpLUCy?es8ps+UnXwz};m(;Hx#yW8Sx!V1FzwR#o z%klO6TTeYL5~S&+Z7-Hub*Xz4tlB5yU!fD!A*TItNZa!)<=#w&KZzN=bBe+1PmRme zFe`R1_6=|RXky-6_n>Mq3&uUxHyS6Z%H)XicU!>rV+ zF7=LVXZ6gDo2q&|p06L*D0Ox5c+9Jc%^$~^WbHCs{oW?@cs%$qJz16C9dx^5F>Sk= z`7iIP7QQ8oQ{JbJ<_a=0{Tn&f9YL*eck$-={*T9m)NTe|<9s|ieez#b|5$f`NNjaa zPf1)yc%Ji_m@;Ud3zG6U6fK<5>|u%wA}U)CGY2rAi03(6X+D z(RWSh8?fYu)z3_B7;0)5G08LTUb*~qXSj-Z+OD)#;A3>uRDM_Z1KjU~GFJTyCht>|n%uP5rv7U$Q1OOm3zI7OqsIrUYvxwu3Xz@+ zUOv`mQ;o9=MQv8Tn?cnfI^llv+O=znUB!TlKw(K|77J3&G%I`na!B#hFZ2YhDm5R4 zEB!{o0av9XhB7-&-X_Td&<)leMtg@7i9fBi6c!`I+ypE>vMVk%DUs+-61#EY%0wRV z6egW4F*s_gni!1eRQxeM&XcacZA(Oy8sDLQGo5r#qNj$U6R6d5BACN22`blTGzVYc zc{D=)N3sl|9KQG}tE`L(nM*`Y16|$6Ryt692(H7iWozWPX(9Ld=QP_HVl^ykuIzgA=P!q>;ltX_ilJuyE18Ij?`;L1;ZJ)5q@|8(N+Ex9j|yWV-E6B#dBu1+|( z12ua(ns|gw(N_-CV=icUb{XpceNQ`aal?O*KK1#REM`Okjw{FcqO;bJ$rlbt?PyYg zv=KERdS+`s9l@mAm)hjvJG=dbq&F52@nXU|p-w4<2Mv+*LJRx`y!&Pl^(QJM*kj9p zV#3;s3wT$z!^%cpXbnCSHUQfuh4z?Ezat_|I_4jjPeIG0}`CcFV!79TkVocyJ*McQ3{#GZJHsdT9?oSjE^YnBbgT1s1D(kn1EmZnBK@}HH(2X_)rUf0fq4Eq|KTFVcw%Y zIn-b6)XFMeIr^VN8J;&dGB^YC;I3W(3<(I`T_N_aE3lKMdlUanE~Vd1t-dvkv}!o5 z)XuwM^!?=Md#$7D&y_4+O6fiao6f{tqdPw14L5C&y<_sc{cJ zXAg3C=B{{0bHi@D-i3uq#x>#J%%(~+k4Tkar|PE~J%8*fuJfF~r_p3H!AQgB5fyMX z;xqqN9p!lcY~|rY>x>3%W)vxDU3x@WPDY$L953EdcQ&&mEp+Vi-x3pwEbWFg zenmk{$7c7rL}u2XNtF#K*=DuBUP8Fmf$w(cYdj{bC-odb8-15@>Vp zv~~QwC3ZB#P22OnEZYgI@PO=auY1wP;*y?P%vlV#f2cfQ9zU=raV<$;evv9$H{Mr1 zgq=0?q|XMfDFK;cEw=}5F;4}=7${Xnt{n?=P zS98_JOcb6UKIk3(Yu-UoVve(LxWFtyUHf~~!3TA1Jg2C~1z+yhO}?(`a&!I6`Dz2x z@+pxmUOiPC(!Z&178y0&F7nKor#e3;k>|;g!aD9(O4>BM!e?Zg>rOGL^c4{1s7yW$rS=UGMagzKCvnpvo&f5OXq;v%ucdd^tH)i}krlzR~52bdNSg z&)ljQOnTA1=Rlrez){hPT_GXik1>wfC z?&fTIf_D^HX0q5%OOl%ho;@z=j`i>+}XQP2^CW^Y}Prg0=AbW}CN_o^HlP{cEUv}gMYGwww7@s)J zb4oo9EiL`26VK0-OFlNg*CRR@!mOrj`&j$uZjXEUYqCfk^`{r9)KSkhK>GnWQvuaQ zA21}P<-T@r?4dqUqBstTl^!-#(*D5dk+?Q`11t3LuGSHW(#QdwUXJgPlLhA3>9AxZ z-Z$hVg_{?UFe0fWa}y}RFdgD-f1}3t915^@xHgQ#84?z4gt_j=Y9Su=8;sAd*B!3m za7;v_40AMT=^MHdaau4ecW}rtnY8t?|E7HP5(JSL+-%dA3rg6fX@64?UD9X#N+I@G z1yzdz!;2ugVGe@%(t(E0K$PLISh$NG)91t%$GCsXd6!Qih7=t?iX0eyz9CpExH3M> z-V!|~!YsGuF1`IIo4;eqr2?glnDn>TfT>Ppt2wXaja~Y$r zORImzhzUNBWk9h!C`dRAcSSl`U}7h9OiwiBi_g$97sOySf0<%6E~8=A&hM%TJ^ezG zfaRz$&i{St{PreS14dUnQ4oQMDvcoNqsNFz0`XwLuuma-7$`R}Tq7Tt0&V{Y6hnkA zeNzyB$-$|HxaZsCgTpSvt!~8=&A1%!3DPrCkOmsB+QqSD9zTUi-E-L3$xl6eAqNli z%g{DmiPda*>Ooq06hQbInCU4!w|%!G2@XeOEuaanDxSx&aRE7(n3AHr9Yf*6&lM+g z>6@;QXtqy>OCWKGgsWb;0cAy5*~NOkgt9&{n?JA3Su#jmJEn#9xZVDk66Up7A6d}HY#LpIfzFWxbmax{riF7Ra zvpGz4<^2g^9OhlFq%Z{_p)Ht1eC1KiVqUyZHIpZEzaj%#au(vbfkk^5g~Zp_DdSab zR$iZ866t`F|3P9Lc>7X~p|GUqlzE%dnlBkA1D_Ay>Gm0*MmAhxn*0)5_{r>Y2TTR= zSQ1d0yb)j_;=TfvDaixSUZkCIe9tGz8kpdt1l>60{oSFa6JM2jzmG?)A?no|Km-5um*N9=tn=H!Hj{ z>FTj~HoeQ|t-@9TL)(3~oUoA)^ZW~UA_Tx#ftsOa(=LEf*lhrh6WJ|>M)(Vs0C4}1 zc`;c%A~M*bFwR3zLimV4*?hQ57SA`PxcqFJesZtPaTD7RI35v`$g%YNWaC)`e{2C@ zxGSUmQtO^4v$&-TS0wro;suwMz1Y6;@L@^icme`YT-veQ?>*8{J4fI<#uKUmI06U4 zC9FPhuJaUTCN3@`u<9)O)TL2ztGO0X70&cM+FVK{ zRy0)bvWm_x!I+OX`HuEheile<*sXA}cI#fjS?~$(HhcP~cTn&@_7yeN36F zL@i7i^hU)~FClfvk&%&REtj&bMWmX(u87_=W?DrehiP~s+VdXk=$+k62cAdFdaCcZ z-g~;qA^(U>k|(d1QI~*t>c(h^@8(rB(ZOG43OASScRei@#k7s?h%IKbLz0=a1LnLm zYCKQ$g(TGvFIB(yF#5vzV0hcIt0w+(8eekQ!Z_2nK|naZif2(H;jz+K^EpX6-9bH3_G7lJ{FNTz zoV*XqBJXYU9H9qckq0YohlNy(sOOS`0k!Os73LV-Rp6QEXEbx{zBM(C zDZBK!QP-7YtR58_vq!2ed{`sn?epKY1|~@IcJKX0CnokQm$D(|3!|`qP)F@nN*<K6jDQ z9x_sP9k(o@47wrR^30rH_U%4l*GHm1^=eN3tv@ssGaJZi7B7hEPb}`y?kml^-#yvk zzIrGwp<7ti;?PxYN1lthY0W2;21Qx+-sjmY*(um(7G3e_M$^v$4XPVS)OY_xX0Sc{ zkV(zUyW*SB5TW2Szs~#n^l2Ue&UO};RUe@slY>XocRZhxbhWt2J74vqyDdR|Hn!`Y zIYWa97`TYffgX~ZcdqUNBLsq8;7QMW~XY>Q#iAhxU zC}_Oq^S-{`;HbI_)ozIvE}Puxx=Z7n-(A^|plVC-?4z+(D5TAd50%ExdZ-P5U=}pQPXxeYYg$Jg+Em~P75asjOvKATwh^=AIIJ3 ziCyXX^H4CCVNy#rb{4Y}L4}2+qZTuN7zpn*^!6gGG5kr|cYJ?p7uT&`fxll|a6OWT z$nL>$p^;Hvr3VBXG*WHEd4bsa{t_vkIXBZ_tPn}Y5Tw5akp-IyU4gf4w9*5#~mhCLcT@mb)-7>Z%1efN-R5hL8sGGi70;#Du*RHJ%%0?_BfA zI(c&AnITW7%33)n>86+kR-(=dV4#5G;~TMlL;tGs!Go9}!(vY^^YUzYpBONG%oXG* z4kF6*zD-@N&2xE;ObVP>@#Tfy2z{* zRnetQVIU}+>3tYh5AT0T_L5sUGbK=*ejAZ3sj`Pwd0} zO>#4os6-Lebh64M`ky6A7o3;mm*IRr^Mq>j4v0eXC=TRtLr~KWW8_ae&f;XEv+*cY zd(t=DJJ+YEor0GM-mVqyd2fJ8k=NdyNm+2(cPr0aOkw#yqNnZF!yYyX5+q9rQz8-; zZ!V^R!y142=sW+Gcy{>I%;Y5tM}LX@1uB@k!Zcf47z1%gw;~e%rIl5#p-(6xV67lv z^S3jN;wfAb%&XiBE?;hm z%98C_{*McgR{Kk!7G4L4U3`gh^99flZIoq@#QF3r*5X1EnH$*%Foi;?%AiX84N&(D zTe~)s@g9`^ujR*XkEki8)?VsR-+lq@5MSu|&d)0ds``exVjhYTuxAe+UPl}U@j<-q z#GvFNb>|mCXMvJ2ue7v7yTk+Bjf)uR=M)w?*xm$zf{dNEt|w89zq~b4`u*NWH4zjL zY>@z1wn3{reQfA;nV5Ykn_@5d-zie2jfifTp7M{DlvL96ErzdxrRA1f({ zn3zy!(B1tSH*rGm1nYbwQ`2wvE18IIGnTqL2MZbG;Ud)d zXM_nvM3MjnnvJaZ7Qn<%)uJ~bzR-1W?89vIQY_DZ9Q1Xs15INFY%T^Hsu=7zQ&-JN zv2u7@mHeDG{~dZ7%IiBegcZ+RRw*()Gc&qX_rANC%MVs<@7j!?$_Bdg-?7xeLndZ?&fejuCY zp_Cd9;jG@p=6^X_r9J%fT9ijps4L zn&eq=tR{GkVr%7C)sBnQk)kT!oH!3Tz5C=e79&4|j6iVfm#MG=p^kwn$oW3zV;Yk6xzL0?1&Lz8F>~j%i3^J^Lqn*s-@<=o* zO0lyw=3eZAc_iPGrpwZ+!h25iFZZyjCOs}KU1vYGK9ZYTIf~A)qxcaB-{(97(he-^94e1!K;k!-dO z&!z*3FAu1y$L^-v%p72*BgJwwj;H^#=EhU9Yq{&XK0Xo7_sdHD`+dFA5+$L-snQYR zlUvoqT`nclykvcq(W2t2*K?xc%=GA3f#vSN1OP*0SWq3P{OC~12HGhmDsSc;e|Uca!>89pX@`x&gcNbz(Y;P&VF zKNwrsmulgVCY%D~E1<**SRD%>Au-}<4*!w1`XKo6DymggLD%*_y?hn)s zmoSr&yYvHrBcN-0)@{IvPJi-bHD29DJ$jaVRoBtCh#u(Uu_0i+ITaBRs|85i_ zl+k)F48;pqrrDuuKhD0BjdfShlfHrDvnlcz zPFtX8wwjrNdvwo%;|;7MHAcLu`Oo6>koAj9dGIe->WK4KJIIx9E{1@Zn|Gx!%|REA zQ;ZZvxW>3_|3WZxaJTZ}i_ux!9aT?(5n&rUyR>8ErFN8Hl;_{6tS8YCKo?$Mi1W0` zp;ic$DeAq0P4jCw`<-Dzd2R8Eh@VZAam9y{fR)pPK8+)jjMT_@kpib1Y@snai_Xq0 z6a~G;!sbr6)GBK3E1n4R?f8%-%uiqaOv$gc5bo4n zSEyXDzQESL-|iCOeURALUdcUAz5GUa2UjQq2pdiJEHFuG8X{wiD;*K>LEq&gAy zDA*k+1yE08pnAL8INT2BGiJ)b8=H6zoEtm{6r^2oSq$HeVw~|2lr}&_iOVdkaSu!g z>UHIU;R!Q`n?R{H9N@uzLMZ z-P5;ob+4$&hT~eV>+ItL`s%;;3b41{=a+k_xkPv1x9Y-K&eKKdN36VQbZ+c9F{!<+ z$Ryi5IC_oF$marEyY0(s$_$=YQj8hh9b(ko@>J#yU;03IBh#BXd*eS40^aCrkC~{e zi3D2&^D#GE@&$=-{D7VU*KXU#mTJy_3szLF_FkYIxb`xNkgj_Ou_8O>b`^T#MkY86+K5+1Bz0%vm zqiWSp-R*LMY6sPps8b*AGNGT@;lil0;=%YOEj;q}@UGPd3@O{gJ6<27pVTc|dzW?Q z+#W4kC;rOqn|_wCYx$)m1dD|l%rf2_e*QgqbVx>W+k%>i-EjBij{R&}%5Uili|I}L zRM;N8+83D>$YgrEpDxdbp5^MkRxX37xiW>^J@-5(qvJKD@Y&n9Up-1C zcEtV$lYkiA)p>^7_Zj8;4BmRZSlW8D!^_;!dUx+cUe%BZWs!;f-MlVmi}>s5H7q`i zfvMG9?*H~i{GxguckZ?!Z$@{^vonF>B1*3erXMl#om`6a+oe4HmXRvy(D-p(X^HKn zKbH+Z*HQ$%IM~R#WJDE6KYZA4lTT@If`pmmXpmgW(58_BQ_c$<4^3K+6{yAi&IzRn z`%OC$@$()_@KS+F;o2@$m)l++mUyP^XWVx%Z~4_rD-*@p7r)K3sx-3xjrsKl7iV$53(VJlH*4c`bY+Mx0q29}Dg6tc7SXwBKfa92vs{^_MZR!HulKEW!1{5;}&N0k(Y;mCvYdMAT=of7M*i zpfb=~@jG#g;=!_5PBSlK?^kbbM5M0-9%sr%WIFK1%=RLrX`;m7S%* zQdTmV`=cdc$s)D;)8(J%#vdop9WdsXYzr~X&wO&-aZn`Q=XjHr_OrvXm~mCtYP)i# zd^4gH;A^fs`_9)vXh1^0>xw3;!!2q(<9$PMCv($RkGvMo5oQuNv-4C^U2Ey|mtSiq zXKG&D7oJ&M`+ynx%I!k~JGG80D`fNZ3J1%U^61I3UTw_mVAEr8h!HSS{KI}c`(CLn ziMe_26tBz;-OdslrQv`& z7?~H>{yhbe5a}dg$k#A3cfQXlO$upln0Q^N{zPO}#L2|7nv82GKwjS^Ba?)w4~Zal ztNr85b$OH#^K}ZI#%mQ7vi(oawG#jwLk5I(k<8uma#s-s0lqj>Ny+{Dh%7w!EqZsc zvJo?J3dGCf+Q&waCXtwzLdAU0{xPVC{y2Nym`%!cauHX&uI1-TC-+Y144@lEjU&w-=*nf{6>K;AM= zubDCDP=Sw~4s6p2BeCYnZ$3Cful>nb%YrYt7JozR{y{AyppN%H+I{HCug@g62aZ>K z)uSKkR`){mA+MjG1mEY2PdNvqX!LMsfy0=c==DEg1QlPiw9Wi9kqTw;6Re`88m7*Oz*(G8^y6au43i4o2RM& zT-V|c6Ulab2xN|MMiJDZBFV;|Qsy>WRh*f+@hlV`7eV96oUFPhe9py%VKNJPyx`0y)f*1DkMpSAz)*Kx z>bPCf1E5ikDDiNl6Sdwv#wZv=0$s(k_Df#e`;)95uOf-6L`!=M>qG`4NrG#(v?{uVCKU*43idznM7VH4 z3RD|ZG$`)sF;hNY{Ru{E1nRImi`gqXaSVQDJtJbZv*4(X?<)A!J0xu>Y&YaRULxdX zeE6GR4SL6<@J0SE&KLkSti+@dstxia5mFg>W4BLL9S@H0|GO~V3NX&4Bbz1yDSbp_ z17U`;VCie!4RG0RNSO_CU+Gj+koV;X~4T=cChR zR~DvP;CtKq==642m$}E01W-fa#ww4=|7h$;XD90bPct#?x1SbB(nj zTNqv$nwknO1Ofnz<2pnvc!9~187@{n?5tw$l5mSlFy!ybJF>l=$sWs2P{B|M=Hzob z7Py?fZqbCOkgC_pD*fp>u%L`#ONM6wmJ-md#K?;yV-fFZZ-onT`2Hf^P7?dJ-E%tA z;d;IOpBjM$lK^Q5vOZpH-Glv=ct!m@O3-{VEVZ)wv|+KPOD^_*e@vdY_f!6?qo1I5 z3}~QtYWthFy=8sDVcN|_+t@@}?KKy4wDexi&yRU-s*bzzSz}X=gob*BR|BB zm;9z;y#Iad?k)vmvo>uRJ+jhiQ*(%%vqTYZK`2cPy5N0F<0!O9q!St&{_A7wS#TVGDjmiXZ3l+ z$fl?_r)yZ#c83()%GKHYUGbX5Wl?#dgV9MLjTav4JAb6FZNEOQeo@svLf_GfX_u4J zo8~(?SxuIe*NWb$xIBya%(Gpifxp(`@SKj^@!i7KVyz{PNm8xJ_l|R^x6Ni#JMA8D zaOIG(rKdX1amS46z}ifc{|k`w%F8}%ACGKH?CR}5B9Xo_Z< zv32i>wbK!Ie4^$Xr{V5E-{tp=l{Yl>K;w*Q}{%~DZ zOq*d<+HP7azro-J&&%flr5pE}Jg)L)$~2o`GcxQH;0R;N+Wk>Z+EMh*#^Vd>(${}~ zu6q?$6oxhEYM#n(ad+fY$(NZ)wX^y0h22psLQ+nWlG;zr~FsHlYch+ar>FI>{xr zPvkD`55C>=0v>AlKVqmoxEtdy@>a=2E!P=aj5D`sb(6b7a9J-E3<0#KYdN<>cn=NM_`qrLFcN9W>H|Rz15d*q%DEOfg zE7EF6!t)V{nq(+FnPOfp;Wr4pn#pfyJ8}aF%osufy{cXMvXR2_iQ8Lnj>BNO{`Tfx zpj4{|;>s~*faQS{h(4f{;Mx<1=-j z7!`_t%K^cc*g@ll!t6c@@uQ%}d;?P@6SV>;Ipv7k!S7YfFreFnf^TZ-44GfNisOzO znjzX0SkU1(({^_Lb|GaDoFd|e4_Pl*u>VEbY?pBE$X5R~Gi2k=2Wro-yFy&65=Lbw zJo}u8EDADu)BEPnp|-u6EsWuaw2TZ*rNa5o%7pg@{w)v{$hM4;J%j^Oi=T;Vv|y>=kMDLKTajR~yppd9x;Sp0aA~eIf(X%5(WX~%iBaI%9aLu|QOKfHSM!MYV+fH;_{i+6%V=Q>nd(=7T|Q7chRG=-4%QKF$s z!o+kP!3t5PWg@%ht#ut{!X$_^AqG}5r1uFsWG#Io7mzb9Zij$8x-!DN`%$;%PYf^k zne1i;YcMcH#qtVe{cF;$9}qiPG)y;*4;&T5sUhRHP%(*vlmYh=D+r25&x;%)G7Acn)&z-HXKbxJ})N7&$h-d4v2fFpb7=CH|mf zybQ3J@Hqec`9pwU65a$VJ?F7wEiEm)DLPTvQFouCwSX0H3v>V-NJl1TD6awxb*dJSHVfM(FqLXbVSGe>hP?|U-I zMmSh9`sCVwKaZdl#-0SW0=oL88g=UY3+(^fb!Yk5(-F>I$X5^&R3Lp397>7`rTAJat3o(W9063FVPH8m^|^H8anV#ddCOv3mFqjcG<^i zYhLGmnV^fJpRPXa9w2yKL2e_5qWiesyKe z(@Xg^|1)6~M-`^2Etc!LS?OzoM9YTgPrWblxEdNYm1@}Olq|cUU4D!GP_Dn}GMkKD z`sr8Z5^aymQXZ7m7xJWiRPIzz^0{MmQi1O3#o>vs6`^5YH+=sl*!;91r09YD*Bi3a zKeIaSo^pU@flZ)`UQIvc`_Rs2zt>~CwAm7)9D^C9lpp_8{&=!r>9uS_QKIZtO3IH6 zcesUW4>^S$OWQI2`&C@~eYs)EpraduY~&s=xW3?v{2dvUYu>l1Z0AQ#&ZpJ#k!SVN z`#-X0GB-axUZ_l=ZpK3Ws@nhhnSC*R9d$H?*d!?GwD;_CPGEUS6NS z4i{Me#jOjuzA@&K9?I*xHN2#YtOk}lqos{DDn_{&7D@o&q}qrHLGqn1^B zzSBnsTc7;=J=-!}`;RThzK-vkZkMG@@cxOje5@_48FwsEvT$>KLAKHVaRI(7#%>!* zDZ6;?nnjs0gME>S%eu|ozB|lLt{+y{-t?*dlj+L@5r^*Ujh?yJ!do;2ZdYz;UARBk zcr~QDIyyJ-#>Oj;f*AMm?&NUY^2S`%_r8a3qCjplQ;VO1M}|cud)e8Vo?kZxDFlwZ z>giy5!sH|S>yV&Etj5pGpnl_5A6nFpiT3+)yG<$RPW5QJy}x$HM_et&gj#gd?TJ)@ z8#{*YJ<6giI9%mXd)ah*cJkJaX^}($+B;%2nm4bGo~!6vmgu$JE8s9cKTndtN}m11 z0BCP{6CyP$&JM74$rcnBlPqWKa!^cR)bkZ87(S*b20H4SG?xbkyY+^OfJOnlQl*aH zA7*t-M4Jg)DtJ5VWasxhImG0y~#ZvH) z5@O(2tA%~%yh1u^0pkMZC^af|BRblUx330@Ri7xqn&UPkOCd;hSO~RG;MdOnSy!+( zrfUM!QjYDC#ghsTEFUr5V8IBjfn7J%$@0;um)Il1en%Y|L-laK3yu~s^NIXX2yO?8 z^%^~Sor!t8lSS{B+kOjun2bJl`=o=dw>F?bp9y^=nFso>{KfW42;~@9)sM{49bp?N zo2-sB1@=NgT$j<_5+{x*HPttla8lxA3!XjbJ<9`oBi50lR#tmL3xOX1F$2m4TX$*O zdJ6QHP2h{XAYq1}pCX(DbeRY=`*(RoJgC6B9sv0Dy2Ao(##gc`Cfe?M&Cpt+G1rcpy z0Zih`<4J$Df zPSAie_yAep%M-ygIk7#fF>yu$B^X$^>2PrFe0=se(GcNgU1^}0b(Gddx4uK}JP*o$ zl%97`L2Qp}+ErXhhCR33s-kQqmqHrrUJQQ*=r;oG0#o{FRK9nCrh#!na+^y1O$w6P zHp=y9NrR0qXmCdt(1SOKLd3ZnP^C4JT&8$40@)eeBo->KtU#4&&TsPW2mF zZ;>jGr9ntQxOL=OQ?8?hkut6&@!D`@2csJKw5SX#(T*SiGmI?>5denI{LvFYZkJZ$ zvlzYn18~#$z!QnOzda>vPzIoUCnQ3w)DB?Vgl$*lKGA?98&`1?hn(S5D~vVyQgrNY z3UVsAGq9clwTEOA0X+g|Z5*W#x)Xp>6?QP;BE1$sUqV2Wc6pcStol33*6X)!hGIHw zRwDo1!2dt=oVH9bN(8ov$-IyoGEqiCS(a;II(%zdCORASMqG8tM3p|afH3dNh{=?e zmPXdFYN`{y7+GhwEIV@Z^Pvt!6$v=puA{eGnn`YWMFzvYQU^_7muhgGR|?|NIq+YaCJCis^-3 zeEbXFO|MJ*Keqlmp3DCKAI6bnXQf5iqf`hPm64q-B_XRqN@j&*RitE;jBFtxdnThr z%9fA_A+wTE!gYV1pYP{){jTeF`Qv=QZ|7OOUN6USJRgtyjNUT5%XM&iMs=Z4N=!4{ zG_L-={jbBLZHq_t=MFiR-FbPLjWLGW=)G#>b>I4bLG~Q(9QW?L6iy0KLUO#`4?gY@ znuhY8(Cgt0bWuMFIl7*yNgT@HVLVf&`qyL!)jIQJ<&^;Hws|e1E*i(ufHiaB7jp^J z8_Fg4grsL*iTmE1aM-_ITbpl^X6KipWA9_MR>oW=#@^&Q{&+LGng5GsP*b?{>vK08 zG*T9Zx0LLB!^7;FTrkOVHW#OjZV6Sb@vWSK73Imoo!2Byw+f$%I(F>4W_A$2E{d(=&{y%ZLeWZPhDV;sH2(H> zDRCQA=+>Hw7ck3JzEs{+Sa91cL1Y_^E7j3|Y&%hf5k`^Y&lcL)?!T8xkNXOYWo4ctdo<>W!U64jPdTJ;v4nLq_4(*u18wyPr`vN+v9AV`^OE_H<&_ zbn4ce>NH?@RNj}JrAFU)mu6-^w}^ve0h8!cg_m8LdDf}vT0F6`eWD{Pdt1%z9w;@J z27EH9@#&PG`Sdc1*YJ@_U*^tkkF!a_acAZQm(RSv9}>ySzGYU$RkJ})X~PMP^=c6k z?`>%#IL4)gu542{Q9BWnExBj6+!1v-q0j~?oui#}@gFLE1$aD6-`dLOFp6#po;!U6 zhX=S69HNpiR}s+stK5ZC8J~JKMhdh6_a)bVWMeuHySv5ajh_fsP5UTI`zp|BA9Owq zOjbi*eJ8Zd`#YWx4mWfT`gXW&)#X>QLuOn4qv)9*Qha-?P!`^Z~=viWL zFH1rd11Zsa;2orFI&||s&Y+h!oC*K;^9nWVg{ZQH$r^?Y%v^)m z@m9ZF{VM@WkCnK{n>7MfVEc4l*mg1+J+SKC9;`wAQe3^38+F&G;_>q@`Ee}wix=qh zR=&N$<>uTsZtnZ%8=-3OK!UPKk@EPMYDWo@q1x<^Q41V+4cQW+vm&#W*1(>SVy(C{ zGT3y{rv-Z%^p<4D&N57kquaC~jS4t`<_(29u!KPcO_4KobFW2ZYlYi6#+NMA=?-x| zP=y13QVUY=6BqyhVc%D}2a^p+p8`LH%8h1VM)y)hAQ<63@!%A;(d>Y)<38M|ys($r zm2)*rFcI=FFdca4CE+L{KoBXJBk3)el)?q*fK`+Y_!52sD<=txdj5CfDT1%Km^a`L zupR!3It9ip;(ln5#|e-9RZ-Kk&%opXsFUV{b2_BX1!EIxJ#fCAc(2hhiB*_BU$d2< zbo4iX{*mu@5WX|Mnqhzb7BdG~Q8Ol@(whw^i3oLsx6J?AhH)7s2q zygP$Y4x<<^7co5IT5HE|2KZat$zCXZ49(1Xb;cL|_zWE$1A*dUbD4X!H7;C0Yh)M_ zTM|tN#8VHUl!ON(m2|jK(Dsn`{FgRKBEoMI`d`9u@oZrv-xw|_P`0Gu?;Eaiq*ubG z@D>C9JDf{oTuG4kumNvG_Idm=egOSr4Sh`p33T>yL~AYVNOjmotY|p#t{fyl-ML3D zG+!^pKp5^v;^KgL;+Kj7Je0H}^F|x^13vRJK-b9Vzw)nD*3aXfbDW{gxtLEX>H$Vn5P=5Qgd=^os zXbVC6C%MvaL!vsA01=o7@d_C@#ds7IlMpZy;W^#w0)QhQ6t}#`1QrXP2u(1x{%02j zv?6gej^LI+&MzDNjj;Ib7To}xzlABxke~$C*n^}~2_bJivT(f9K)FulU?}!r413Wq zM%XgUVk9vKunp048=9Nzz~TT-5tg)805SyaOF{`tZZ!}pvQFR+LN$Pz5R|MR#r#6b zB(XD4uzb8BGxSpyYI3{*_yK|Q0kT|}>F=B3ny$ZtY_?!L2_iyV{||N&0+ zyhB1fO;^){T{EDpEb7irOLH~JC+lcToYm=C^m)%n%Vmf!dC9qy7>W-Gsdv7*{+NO5 zcwzwm0NY8ft!(~B&F!UWezM)oUbK}R74|i?3a{0PO?NCYE^Gh88qO6ipK2_(KVT|a zFYn0XFTb)bM8~FT!bnDIF-oJuci!mO3${0xb_j0JW$Fy0Pu$DQQKS07+htc0t?k*k zWNm@c%;44CTFrsJ|BgJm;w?+}Op9&ci8O=wT7kIn;E#89x<9|x7^pdm+na_b|DE4L zE$g-8w%Sm<3%d%NYnbXpvU2}t*^*}(y!!QD9~i0}AG@~q-)!$b z^QD@4PQ|3t$oFtqFiuS*Q)k=Z2u;hBD7Vaw=TZX7LbZL)Pb}?a$yrn$8+)pKKsh^_ zmeGy(mH?aJHjPBOD3Oohk^vre>fk1v;uKV@=E5WkMdcwf!Ib&3J}svZT_*cs| zrubv%cd54XjC>kmRpN4@?B(y&?29*?Q;4lyzc)MSPJok*j9y*@b*|oKTVaD+f|q=r z`n%Xj7}hgqbl3EG>Qnhd+kaPR3F)n$(23ZVBe9=ds31YBn0`2`)OxBnZJw;C=@OQKTtO zxI9=Zoi;QM<=>7C?xi^6?a)xL$pc#AeXEm-s!Dy9V~U)Jw5>l?BMNc|@x&ttuM*^- zU(pDZ>c4mZo5U~}p)+8IvCxF$z9-F8q;Lr-RHR>$m=<DQvRs_1HnLnon6WWtBzAiD#BLpUApOHa_T;a|#7tWk6#c-2ASfs(o6vlA&OLPz zG3dhcv7UKBdni`8=5~gtk~ld$T3=3k7F}QS1=cA$nNtMiFp3?NInd|+2j?A1i$1-$ zR<-0)Rk#H13AIA^fC~SBnjuvDyDhI4TU>_W$&=`WyH?bV1=-NZNCoT--gH*>?WS+3 z=*7gNIu1^-R>)FNxBKId59&(M`p89)dK|re0$Ko};i$sj;TvDA^H$#(?cP!8aURDB8r-3)97(5U(TXIky5Kqx0ZH?)Ne#|iD6Fx@Ba;l% z0hU&~lNlJih@%awEYW9}l{$qqeH()Qg`D0W>OUd-os?2mz}+vpYZu$ltM!}BF>9?0 zeoaIRfutx^i3i4C;-kah^kZrGUZ_|G957p+b?isPxkh^`)v>8#-j7-ANur2St^dcOV~*XGXq1^ zJ6uHLJw?J8QUFi4)#YRFhYkONlx}LenaRoj4a|VuL0A}?JUG4f zx4S-U_wL;cT9L9qqfpHr4@~ha=5Y+zD#3qH;`QOhfu8RAGY>6pwv46a=NpVPei}R5 zry+9#w8+T6Pu}^~TfpiOiVk9HX6uzd2oIgZRDJuRskiS`Ng2=u^cC*=mNplBcCg2TlMB9vdA^~5vF4~RMf97wdt@Vi->O6>sbNFs;9 zb0Vn-$jz<%zDL-@F~?isnF3U@yk`eujXPY>G;rl$;?5O|k^kMWhj0D3ew0E&DO#Xi zy@HPMQXqX>msa+`J(a^Bc3%|aS+9OKq;ngT6(w^+o!2?visi5N`v(}=d&rJUyeyP- zD0Yw-7MfYu^*!xi`E8Ehj!_dQ4f&ebqLZI|pmfNXaTRuE3@wjwWE@z5?oo|voCf3xJ?znwnhd`;sjWru#*UCGQf-hxq zTE~q2X>RNDU3h)8{5enm0gyq8ZDIbZ!EV`az=ncPeBwvu+Ob^VrVjnS>lTLc!*UrNoM-^sxK{LnG;3taXd7)Z2(n+43- zG8k-5XSVLvxo_Uc_xs)r%1EahrcJ(@GaqwX7m_L!whY>5CR69``(swEH?&>Xbdx`I zg-{3^Q|`N3lM$uH`?>FyGSbT~dhp9touce(JS6Y+d&_pIjtScQnj`TvAF4D3at>}C z9{0A16A_W6&mE5H$hpI!Y}v$i`!v;trN7x%S`0(-+BKaz#1n)f_4zh+Y3=zK&c>EF zvn%$%pK$Z=e;<=qi^?DDcg-HQH=UHa@?gV$Yc|UyB?Wt)s{*^ASAv z!JVqRw{wnW{gHgKD{!gwO$%oduk|Z2R^AHz%wKnd1&{T$1Y`)_u?b*35Kc8v@O`Uk z!k*EP&V$>k6u4yFat|FY(9k$+&`sI`sPNq!TR2@+Ib*e&X z;cyQ}q*%(8GnY=@>N=)>BH$yHwAAp(V%OSwpp=EaD_%GvUt?C)_4PfBXWwVfz6)Hk zm;7KPrJr50m!hx5F0zE=$a*|X;tquivC#p9;tb{1Fn`&0t9R${S071%5&Cpi6_U;Md=nSDRh`@N+Xj$#KI?|Qm9P5o}F_u>rW zHmP7Z269e{7hVg~b>XXTURB&|;m+D$bvYBokRXZU0LL3+L&0;BFa^-A$v<%^8T)=L zD1ecA0tGI^>R2~+P2F+&l+xd3&tYN$y?=m}La8@$HL!hMdUUtF05~$K zYQ;VVrGHRKZ$Wy;JvCXb3_O5XbjTuuG>f1t&eEzB0j;6l^Uoe6po5JW=s9t>vo>tIk;Q zRRoeE+?R2;{1xgowqX z8g`~rk4|(n@88aR>qrn?--s}eW)p!&VFGCxV~=IV`@-!gAB0Eo-*eFm(;Gyb%}}!yXmdaflc%lKBap)-N#$~y=N+J zp@jZNS8Dfr8*nh6s zDVomqne-`IN3VOOYMR~_y|yRJ&EI41)7NeQqXxMrTtM#Vmmn_Dv?8<+g3H%u; zD~Ng_o1gtz28XzerB(D-Tr!XT)L7eGUNF50AOrG;9E&7yq${D^La-24A`FkRAlC6B z_jxZ&StEb@1O36c=3D(KbkT$HngXhCG@wA030v^68goT>kIZEGArJ{9Y!}ff5VIC} zFJ2^#?7)^?>%%_SE~aJVyM;QnRX(|t0M!OI*T9X3S)xPWJ$^gwX>4Rv55Pu5RFoH5 z7N}VZbn|Eo-yayJsA%4hR?vLT$244 zHQV?9W8P(acTlR9?V8Yy9ewajLi~aH(S)SGT0$(8Ya*NLQZC+i56T$)6?B4ZFBu@rMFq7LpGn^V!f_W*!q54`aKU`@48^6U*R_+#3P<_J{dLGB$@Yn zXlAt5Ui?Y?_RO!<)MuH)9R*75#Ll*!Pu(e_IW3u!L!qT&a9XILRMCh2y%*hnC({M_ zi{Cq4Iky$5CbvarHZE>&x*3wj7ubH~k*9~Esf;;|#@IWjoV>E6yW3itN^|}ch%?6o z30=|LKJ+F+IO7g4ifj3Ew|x`m^Ln-3 z*E{Sss+MJB#j^8X9i%>f%rlHh>0gZYWxy$i%i`E#43)|(1r(rsc^_oG{&l?7@lygt zQxSW3M?YPpX16=ZC-pI~ri|VAKuFF}Lxlr3ExT2zUHo>)h}%Ysf1UlBIujVQNqgVP z1R6*0t>+llH}a;O82@OR8JoVaZQTFOho=tn*|)n&tOB+^6i+>CoA<8mWfVs(zw##P zo8>MEPpqVrUe;5G9NPBC%&^R3qOz+bT<&A&4vBDqEHiescV)kB>wK8_S7&XpTfO{e zj8AxPCVwN3x8S8CL))1|-^TIps`HLu*YH=F$Pasw*+|nKbG6<5-|K;+k7S!v3!R2i zC)v}JzQ67`sm9beSnSlrRyBWixn_hqw^zwXV8riEgp`(33Dx_y+_25FfA-#Y4|$cy z5fZ%jS8NW~W1-Myf63LH|J4HI&m3x}3u|$rxO`%RV@80Xv(cRA{qc|Kf{)^q%=v~N zTDH?vT!QSuJNwl}q1RWBW(S$^q{V#n*!P5Ol$Cwne7!~nW1(vHd;2ff_yQwE(?&(o zJ;zxh#U}R%bjnDIBumFce_B25l~}YhxTp|8Mz1$JbmwtHPt6Tolob^Cd_nEvjD}Ah zMnw2$IN*l837-k8>tS%KUne29_cFB3aWE^|wURLss>`^-fsyTu@~7UnCypO~12_xv zX7-UrG$lj`gn*4MgpTg(;B7B&Z9LF4Vs9qr?EPrj=dkiZum2mAftp(80@=jsMw=g8p(&t(X6}n+ z%d%|cq7t?*6NTfmE>5KRip3qD&wTtC?F52i7dg~e6+wQ&lBd&=4kcy%uclH4B#QEx zdw{wF4$T^(3gr<6pHn%a93X3h&Ikad6KCo*F{gngaHk32zhN|Nf{v^ld6RlE7k=3e zE-s_o$p%U^fNbyO3+yn27JSeeVKs%^-F$d5BP&znk zt@jnf796n^dAItQW6H5)JVC5F2wv<%Qv5?W;OTC?{OsPLOTc^RL_vW68BeF4WAOtXlQ=coM zfR}Kn)rm90#dsC6t#TB5z)J8OP9G9N1{@gg3=v>B?9BIWzO`{@2KF2xt*ZZei4fXf zeCBjnA6Wocz&`UjX%Pxk%|2TRV8OWoX0)@1FK4j?#Ni?|+nHdf4_ceN zHyeC1KFDML;FRI?>5n4Z7gIAC(342>k4;}M+s|$E3SX6T_OOwB}fx9l=`?h zfF2XQh`^pby5V#9q=z=wEUkw~4$4Or_g7b|EekecWT4Dn@Q1Y^L{AYnyi$5ibI4tZJ<159o51d^dXa|frM4$VQE(lznr(gHkVENgW`P1Ca z6Bl<#G`{G&=bRk%$@6aZ!}@U-_CpsRxyDpH;yw9F$Z9mUA>~ao3*G(Xpk6Ib>LNpD zn;_luzEmG4bv~#z-dD{GoO)nT)pXoszm=-aoxjE$Ta5;HPVm_VY}R<8nys2SQB7Of zp)}pL~JNKQ^hbXP2!Iz!9m3u;Biz_XCjy$5?Ti_H@e1yjAkLcN5 zQn&5DcVE@oKC0g6w?uRIT@UMfF$1@-JpTRb)gSxE7XR9udOGumyNMwKhvbgmhI4ux z##94J!v4w^1varM7z?~;i`yAbclNPXe$9px!W%NIhZ@~kvRqhC=lNIBeRY<)P06eE z?<8HLrZ#`>4`@z#XtQNk-f!eMCrF}f5N8ZNo{aVwiweZYGDkscVK&LZ|SyJ7s-~4R5 zK;dBrk8{*I%!fK^UAU5%!x&?8PWfcrpr1ee%3(>;KT%w~YxL^2SAy+%mxp(byz1mB zJgEFKAf_Qg&P4OBca-Is)RCrirnzL3_>UzRPt94L*9Ja&$lkaIEhY6oVg@5<_i0-l z`(s!r(0W;*kP z>ZqUuy$p?Xhjxx8FMp1E2kRQJOdWrTauT~`Wce@P{e^-2unV=Ws0DsAn%+DW4!WjsvlD{ML zb=K3TM^AAnO=U~qYon|nch$mG~DA`B;z4&*wZIm-$f(GO0xP<6-Z z6aw%jGB12P&^E014lQg!06LL!VYV$SDk3rW(9NO2%IL%i;lr5Q+;IN6{TNgx%u!W8 zkqiRPjm=8{Y1 zQiVA&H6vmJ+RAsR0yt_`Y+r|;jvTS?Yj3~A-#Texadj=}+D(P-+^aWjMAVf!fS()~;OQ0rQx>;RWFf3~@IF*m5XxrJ_yBzZr zlB{d4Kazi2DNRAO8mX7mGWX-4*GYD_(20gk_5F5TEHBH@tHV?@rLUH*89LX0Qb;+(sc=R9MWE( z`x#YL@!+S(kVmS`(7k!nnp|dkUo6=?Qzs$Eqb09%kN&Kly><^lr0xCvACNS7j!^FV zXx(LLc|~^bCZyXpxPpzph`Ju|y9}8qDLeXwmVrxfF?lKf{1WzVX67PjqL;w?)nRQR zqP@z=JoJLGG2g28rX)*CTbhR3o2}5>4DqYjQ&a$tk<6B?@^6 zkuoBL2LOIpstAE39(heXL6!{KDPV#5=XZWesod2L2;*ab5JZuZm$w(F-A$GN;!&E< zb(Vm#lj1OJqR%Y$RsUv-64D)d#jZa>LJ7;)R+oW5ZG?A${1^~nTQ5<0f}nJ?q=~6& zx&D8W&a$%KK{NQiW^K8~dn*G4xFMpMfC_p)godj>cdkWz{CND!wM7f3zO#7mq=qLQ zA3w@iE#Y7|vTIjv_AaCj=bS1FBh(unfNn?~{Dpm?{o}_sn_v7`PZTK6_I~<`r~+i~ zkfgwqMQ4c@nB*+4!t8C+hlJu7qG?5lNfbn^zlL|N#poaYGqS%)bD~lYEftD?8ukBf z3cP&Ys2}_W7uZk21^!D~o4cNslVgX6m=hL&W4m7CeqKVeMPT0K)R*e*CMFdSvz#oy z#H&J!Z^wT8nI*CtL27aPK>pzKmEp7%tmK5pBkq{Q_q~^2%yFYN;|+kXYR0>W&I7Le z0S~H=;M-YWpS3gv=!357-p(n2~)YH8+>ZkBBxvjcqV7 z`LEOn845L1FY)GF5!W|&Z*A$GecT&Jcbkxu(|^#R9=U~)f=90)bkpzOO9A9qV}K93 z{^#NK6ogYdUz^+t|-Tf`FZsXn+C6b^Xc$$eYmxL6P96OS$^G5kx-k&A%#l^UM zx$YEO#m3;phA36F7YDR-&X#&DEF``G@uP*v*7)*T`3= zbM|@`$jo0JbhgA*!&Wrry=A9}Y4rgLKVGx?Y z7&=qtI#SAP^m|EZdOwe(rSRdtMxIC+YW1wn)hzMrdxO|NIi9-D_s5cRC}89lSDUDO zgec`uQDD$bd%cKJv-Ks%tV%T5oX5m_bM*zYysoFSG}kU>EBZAYu&lB$a@7k8IH~-$ zU>_w@@h(dGS9e z#_oq=c_l41cBU9zG7E1N;V$Y@p{DkFaQpb~+VJ$_W183PN6x)2Z$Ex}?6jxk>A$7= z$fjwZ_qsXwL&U5kdFiRO&Hlho4wJl`oTX*-*h#n_Ov)H}YF+GjqwtxU z_~Ok1HW_Vy9&uGQQ5MtO$c4@O4O*j^0(8eh9q3z_SPyEa{dz}rM}0dhO|_Zw9z$`@ zQ+6`(U(6j1A4DGIvwuzh*I<*wQ59S1Z09}h6K6(}Jq+9F(tpz_(OXze6*v5Q%_Gn6 z*D6?{FAzJQQFQBfoDnOV-GLZ^fe)++PuLGB3ma^II}t7-(Ryo_2Mc zv(?i^>S?uKHRgvKu_bW}NJ{t_)Xfw+{Pm!EBBy-WBZ&X4U{U1f(_&&y9Mc^)gZY%t zKKB13$dGF!EUT@y$LP8AF7HBX`WGw$pObBc1f3ps) zerr}vm|We{a;2w6EO7!=ulfF%F(1zD`un}QolPg)g3B_xPkB&YA45c3h6^lkZa7EC zh^P4PrQ(l|r<{MHvGNCG_@JX(@{yzQpQ;d^Lgeg2@%0#-3J!Xl5T-Cu;!%$C6{cH~ zUk^prm+vo(_2gbhg#cn9%|hf6X!#@MGH+|(W0)1mT8gbU6c1;T(1CC@L{uW1^EQq& zc3SEyc&cMEd+@MgNZySe5*~}u zZZ0b6Ul7JIm`b-`42AikxU7u0SBror>8sZ#Jj97pgc=p{W9{w#Ta}+Yp?fJ#EOY<4 z6EPGbj}~hEHWbNXr6-WWCme5hQbI?PETXM=HPU?VThct({bE5H<0^3lQQ%qs4bN~j z6O-P)zN^rUUICJ`oG@|s?J#jclT!n*5h$+@7~hEmKEt69q&&$9ZbxZ0QpTXkkR{cp z1oJ>1+89nY{P9CXwdgv6rgBc~NUiR-<<*ELvho}i?vF^Ze30x8b<-8*siC!lbWjIKIeOwvz z-|&3s+q!-RlJ9W_0uo}1B?l%T4h9nwv=qdg(?AZ+8N9HoCy%0Y1_%#1cgePa@sP{{ zSFt<%m3t1xkaXTa&_4>IZWxN*v!Jk${!Hd2B2Yz(9^TP8jE^gM`a)5V@3`x-d?TEFhPPXm)&hF1fb;p+VSPgm)Vnuh$EnARd=&$M z+;nkA`FLvR76nf+#g*^+J(V82k&p;{1c@}HYYAsB#)7&ik>#B7x~>=P5_#qk@jV~A zx?aPntYdA>i))MovJxhYV3H)kgXAfLkcA`%KusnxXEB?RWDxvp7}Nzxdo7G8c952REA3N+k=$Vt8fAN$oF{-pH>P7RT`88UR@#u(jw`VT$W z#Pp|6LvdRniz!w|XkXMf18AT~2;v|{;xXj!eoI>Yo)jbJCWO0pV!{#2Qs8~Etrfuq zvM+hGy)5I^Dq6F)T&k{$G#R9x!@Y~JLj*w;L0Yb-FfB#=(Rzw2?nE05dLI^d%iK6D zK0olyBeWkwWd(25540Jt-H&PR!syfm|DT{$?L->19l(WmMQ#0*`K}r;*N1-n95WO6?x^?^xllogxMd+~?AY1ru7*BRDw{ zwyJE;PJXaEjNQN`z^LK^CwJ3+fy5Ts%&v{1;X;O?9Gxout?TP)w7sa5t68lkWD9ko z*~KK;ICB>>St?boq^f5ryA9lK6LGkXUFb#*|8U1xk*b8imlhT-ZnZt^rn)ju9?02= z=)2HGCO_CkJ<+wg+jyb4Bz8}nWss^@@aw%i`;$e90&4qpJ(l{*_Z@%Y^Bs@h6vP~R zkI8H`CViyIFu>L%fc;oN^w4^<_d@YAmtEbf!|NuRpB>^02=BLxv^>YMn?4|DiZgWT z^m|`hfz)sx(NNkWjayj*hE495d&hC)Fl!C<=kO;9PU&|~D)qBcrOfKINI{pmQNve&K)V(i7`+9R_4J*T{SH*=-2a&_#l*~1n=xploOTlH~4ar4(x z#ah?xWMj8gzde5AGc$u*=3Q3~>Z2wq#?r2Rxrr2#cE{s`X@kPrSyK`hJXD4Pxi07n zCS0AdWv{3~bYY}s;LXFGJ^Q3f#MPe#T$~( zf=Y30(j?UMi!{zbWZ3_wjwG}~WF#BnKQJCWa(ySdGNOow3wF!JRmClgUPe~8RDhO{ zrW(jmiRu!m0$$P_wW8`ZB%>V}@COfFUUCm3ks(sLwt%oOCcILpGe=_)i8(+vJrp@) z&I6p}Mai_tF@s%~bVOr!d~xD=dD+V!_l@-$qk^`{yJw9(`i$dewr1_{g$ox}s4$SW zwYAkDEhMmre8mf=@ebeVaRYtMa}r8v-+7GL{RY$tQY*1 zhm$cLVVWV?6>xdL*gzmMOL3$K-z77F&Li0fJb;t9cW^a^-a z6aIVM(;Jjw+=oyGMMC#Yqra7){~Zy&ic57HW$|w*NsJ!OG4F$PQD5xi)a4UKNxU>> z-0Q187)hUZFHZNpB{&HNy`%W`ByzFIvI!wvL=sQ?;v!8(4QK|=^BX6s8vYX`&vzOVL9ON2p(AIKuW1@`IP3Ly zr?q7X%bDOWfPp-UV+1HM=FJg@IcLnPByJJ&Op}6I!iTdJkVS)tC1z_(qx-hd>9Ksp z?}&h%wvcCiStTW381~-+z(5WKx)A+A;qEAX-lueJM;7O;-S;-rQa>O1Q9XHjKj;5y z0rp=skF6u-F`^|E6+~h`F}_?a9cH7YCbtFw?3Tu^ubv}NiNss3;p&T9Xb@Q;MIsrI zhMXuC6Teab3f|J|yCAu>0AbHsltFF`Q12BW+Uf!6QpF$^h^O7407#jXC(}r2Y3XSe zYuJgAe;ot4a4o~J>7Jrn0P4mxMXOWWtFHo+Bf)Xxf9FPHQf96UVPd~&whEq|pPcMYt-KM5JkpavfA~K7W1~xqzccEvYVR{@klQGly|viP#&Ko`q6Bl zl_qJ0p=P1$C|ho}Xt&c05U}6?xE*PoZtf+kH0jf>dMPc*Ki|O`V2MIWon6t#TRE%M~~~ z6)eo&IiS2VHI;Gmr@Jz5KDI=gMBH+JG24~#(XM`vyKn%R=icm;Bb6*(pk=_ z4cI45%C#|yqW&QO+B(n?*ICS89->bK3N zzJSj!b|_r05l)r}caNFYq+754G&s2S^R(=ZWw{F7*0)(A*Vd^A-bpD*)5yy?H&yI< zqt^7U!Gka}p8O&!_c*)K_$`yV=~h+_`#VGpc4S@V+w?|e`_l_7T8HQg*Znr6|4Juc zr_G(hQ22qSr(dXI52qymMzNHSMc>z}KM5|c{rAcV|@Q_0xSI8x+o*Q%byLgSWpqf2PRfCiI609Camuk106 zyCS9z@#dj73L5Jt3KcE*+3x!(IU7{3wVY3)Xb)ehb)Px6IQ2*4X#>ZI%&CU^Q#Ffm z$tI#%{F@qNXN`|3PwUPVc$SDTU{-&;6#}g1RwJRz? zafauWltXCBwtFm|jOiRb8^ z!*YZqhJp+6!1kYg2|Tqk84Fv7p2A%`q2PqT;|laj0T?(v^kPD2f{ZB3BzR--?dGH^ zmz$$bW!iy~45Z13i;v%mI0+CBsqm#(bVF?e14cd^r_WP=1Ngu|LkI5;@Ok=4tg_H(AE>c=8n59E)6Jw&q-DjaDNu5WHCMAK$E zKH?B|VIxGgU+5$K$$**wlc}RtuBHX&CmCq){EXCojL@)?mI*igH8VS~^#1!Z2W`}x zUa2!eUSa*+-XT?8Afp?FrIk!N_zhn9`#n<*a2#owKNq>$<0U2oI@BDR&Bc>wGH4F$1CNlX81)7W`6Gh?Gv;;n#qIVo_1$W0|& zC`7W2ZmSiguwIS1p;7Uwv?9)%M&zGd!2%0iYokKlEN8wV<`_A2MS{wjD>NW~0`SWp z!gdxapgkFOGBRV~^?_~D6N-|{iBn5jkCdEe&&IOS!fYUX!8M;fQ0aPpm2~KUmI9{? zgsmmnZTM^#IbaYwNYeKKdEtylDJlU=bUJT9o)CNu z4U`|7YUQ~YUnHXkV0YyF4%e@B-nzpPhAuUW(&4M66UmVnBBe=aQ1A9{^kcWpp_w{mao5 zMNzmFQa;EO36HN(lr=PW2=LhcK2ACmpE?=GNniNbbI+Pd(mh16ul8_EeP(71$8Ngk0f5$j4?38gbwM;^725E)FPcA>8Z=?-+6S-mQXmXcJ z7m2tY_u7{gbXSUqiP@p%i{#J1C5(%j*yPC+hDL{<3-Mo)_=5(~u0H4JN{riki&RR7Fz%N-99Y(CyU!7McG%=;RiMJAoju9T+UTIDh{4C1*6ULCNW@vbm38 z&<}}g`pNY2JKwj?I||i|?+zH~Z@rn}z=Pa5UgTwpT zTo28L&}Z(uR&(k@qBr;W;5om*C>u$a+edE67WZzdjyN>jUlud3BywoF!>c#^YwbyH zs?5`Ya&sFquE{;(o891eD%fYc+j8{%2I@rW`@vLSDfI?-(znNkK38d3mok5+pMu#q zmZnFL`>i`;Fzwj=zpK5YVWIx~xiX)YH+m-4igFArm)+^N9j5!9tHJtmYnA`hqr=^2 zT^@}p)07y;4i7r%|909ErpNHWAtQ_G z8MpkJKnrgv@4u3*9@WzpW3?-?$3BC zVbojrjCVGCMSRX~{`sO(&BLk0g~#c}@69&^q^^A#EZx^Qc>c&_LFSjF1oy+|TK}wb z)=iUJE9Se=yIqmVPimm|u88Yj!DBp~Z^LM4{fhtX^)Kji9o72%TlCR~=}v><{?7E;kcZr_PAgOg|Sm$f2?E zhRW6H{kmvXJ+5zgR@|n*v_@j;{Mg=${*(L>U-xgx8!^_<*nnKEb-g(k`rJJ{uFhXF zUpSy&dZDAgKcNh7wD)th{=g*Wg%2M-P*eF+E>^U}oA1@Yt38q9!-IE+S7KtaxP9Ei zrQ7@b`KKp#Q#>-B{qH*_R=vP63JeuLeRe*u-*zKsUnvriR-86I>R9wwva0p|<|=1q zX7=lEx(Gcy34t^MUVT>wW`|cw_C3(PTC&3&JOV}$4JWSj`*uYqpM44PzK2w4gqA}C z5g8dtIbF~6Ilx$URmpSrWAvi|kepix#h* zW^unCf+8L-%3i<4*VuJlwyFEC^fc`jct+YUrMKE`8;g_L@UX6H`7f2f*(J@~`Tu#a z{onuW*|@S_>3_cK>VbA+=UvA#(_X=mw>_=rV7^makX6NP%QWj^Q@ceotL4+VhyPz+ z*Dw5|2>$PV-V3F6po*rpko8bx9+hNhm^YvM^Kjdb|NAHX&;OI7^*M0&|NMn<(KmV5 zWkfYosH{&PZYOWx|L<$eIbK|qC1|nnU6>1Mb5{8b>H2$nV5*~ z-Fp_>D95Hvyg`a(|IhcDbl?vq3!O?}LmTC9?Jhd6-#pam$G4Bfy-k(?^#y4vP@5SFr{{TY#1)P^!A(P;+p1XjfKCL^H=2V zd(I~aU;ejnS$9_A#NpwX>zq@Ti>YIWCn#q+s;>Rwc_zo@^(;4VYv7k_Y}If6&0QR* z>$RL+5Ir=#{CujVO^#R2D2ylyV-qIIiPNj!P80DDSd+2R zzKdxY8Kckoi)$3%#slm{0RW9!?a#o&ak0mC*b3IY;H0I7;N+iqcO_bdqHknpuM4YL z{_D@8lAJxk>hQE`cI!?53w`!+H}fV%ZWj);l!-BPNHuj#itOsb?nRVflr6O4xQB-b(ITZ8Lte>aM(~?R#hq-w4@{BEejH`wX17=EfGv} zvv){+pP7NMN?~>u^qXp08~@oNBW;wDWBA*0^XrNQt+KA}hvf6^31cgFEoLjY?yURp z^W5XI>1!X)eL7Nb=%qN`hxf0#ncTc-?(eJrU4~rp01sw5oW-_xW~uzyWXQ+l@7w%h z!MmgQP`rw&*GiiQ)p?s5J=NF;w{}v8^SJRik#4}000Zjd_4U2GBiK~!{U)7w#-3+H zKgp8t@40AQrZvSC{G+zJfkVW_AaqkF)p~_9ww~H{@j>Ftv-c($nP?4^Gs2YCmciFQ z`nlQg+`q@X8^^_>Aan^f)$emX)N@#)y)k+{uQCTivCVq5%`p}-o=@*4a9?@48@1TOB&cCovN8T0QN%&1p-PKAk}-&uwb66neJzdJI)1nb(r07nl6| z9!?je)h8}|Td?((X%Gk&5mBSU7Jf$Xc-7sh@CGaWuiyFG#ltO<@6&g(`ENTe-s?0a z(^ue@px@^Z*!sFI_rr_HrfU5NB@X}N$$U&9T@JQonzZ5uVf`r{q6~cNQniygX*-hx zVv7sAbjOs6Q_xBvU zwqk`cH5?g`CTgHf3W=N*)5|Y+@@1mmTosH{7 z1>sU81^3v>Myt>E$&k(t5M*ngsS@*wK$D`p{=2)p^FNDo20WK8ZPe*;;uQU6T_PR- z0lfVyEP}Ae@Nv)+%$g8w4h|`p8Is)ewY4r^K9HR)3AP*RRyH5eA;-^sx8AHn zRRYnAL6AnUK!TOJv%F&Jl}1#6kr5srxJmLc0O9e0?4)OF(Pt3cL{oB}mpb~)Rpw#- z8^~$r%|! zEeqzEQ>%X;XYWFaC@*#C1>4ILW~4 zn>ansB;~UMG2+BHOOz;t_vgO+S;8uRcZ4RaDkM)GRCpSY0bXiQy+pH#d@71t&P5nS zA;Lt`UyQ5^A3{FxX2qIPfxyDP@c;sndAA8P{#kE<4a9qz+dhm!CB^^IzAMu$-&{?j zJw9D?xzKXWcNcoa7Z$PiA;qC|poIUiK!TpJfq@eIGKFGBPdfdU4Uwx20-mh?$B(N6 zNrym!{KAndzQ`@k0`>)&p2$mi*u?#F(K7ZWmOba9f_gPPNCo}C<=JLEpTDX z$nMJl*4Nc}Man$EN!)y%=HW^S0knZrL7(+PKMEVCTehSeTX2l|W^U~68yqkkU|&Md zN8Qc_hBD-GVeyK@dPAB+tb30sH~wMedrO^f9C_%(Wu9BEBLR#ut#n_9J?PKfqW$P- z${krM9F;D!Wu3yXV1ta7)g$$byq(D|V5Qyxa;pRZX{nJA0R{9CC#s|1Q( zrw{jVeBY?6T4!CHdt&>UOx4{6qLuPGc1{#x2E*op9K3_Z7cwA)3D17NBZvg)> z)%3As?*n!+vBi7zB9=e7Q(e|bzP6yJH7=WwT6BbZ(%sIieL?f~NQU|KzR8nSKP!~! z^=MghI#?t^-UoCEXkB$?%61)=`0>*8a@$D=ig=%wc6kwaG&k>woTPWJoQ5zcOJa_oQR5RNLa>D0x*~! zrDrqJ5P=l6AQqqr5T0HMBx(z=4+fT&zSQyyiyCO^{9NvG2hbUb+<`oqtJv+MIOFWK zJ$v@lM!39Lk;dMK{hvrAP;rdGX2a^=74iTC21Zf=@xUNLb=)7y3Vmxk!i|Vf*C8|E zjh1l22#$4shd9*j9NBotz1Dd0R)F;ugvPlLMLN10|1X-ZJRYm9ZJR@qDIv-f8Hyrf zGK5fMNM)uHA{iq}2%*SOqKIS&nT1G*qGT>ZlBuE)g;3&;aj)$k81MP~e;}=hWbz?HaPT!ukS`sm)oNB(ympK#wc*JH%V72275Ol4@L7?ts99QxT6z2vgiWLfSWAPv_Ua0TZ(?uKmwi%7>& zl$NIuCG=P2;;s_I5QjQ-}NtB zbybXfGHeV%i3j-82ksPc6R;Tild!oDCn;vCWV}b*U9Gl(G>F+)DIOA45WM~V8xmvI z+l3h*?x$*nq?A1Scf_#S5kQG+_Y2P~c>s@uR|qZ?uuFA?k8A6$-HQAA^$Y@aMaZ5G zOFH=Ru6!0%eLW!hF|C;_6hcygliZ`6oci)}JX^N(7uivh#LO#84cqX4a=Sf76AW?^ zxGP~``xl*RV*W379%TVry10dm-r-u$M!3D2xh+oxg_rjENy%Si8%yTO&g+2NfAtv z%}~Ew42r{|Cy!5RNx9G560B!V0xi69>6CxQY(({~U;A35k?hK2Fz zBHr8mwpUOtu!|fXhFXBwGLcu>X7&4VEZ42ZLzhW68qIrE^@(_JL8jS<9|i)k(KDIm zpOxkPvDb|7)JKRXs$^5{ySQaMBmaYl;f>Y2(|q(vS7#Mcy4I(Mh&CnN#owhc@>VBC&of7n@MmLA-_;->k|KoC8oS0!QzvV+iBkQ{7YQgUY%tVDWf@6Zy zJsUVuw`q%U+6Je@p`_+HzN3KoA&2&1LpMJ9rS#UIYljK4`}eOS$9!gIN;?P-bHIi)nH1m)GzB78g^JAdv0U( zIT6R*w||B11_|btTKhiTpks9A5n<*>)ecGOp4Dng5?x%-ci1uN;!z5{h_r>y`+@2d zP0M|ZFMG9iD5-?7QHRtS=#)KR>=ftA-={r?0hw}ZLkmr@`j6m=XZ%_YVG)PBVsk@y zCoGPhUT093wYZ%xF1Y)^EvXw15-1~sr*0X99opL{?sY4DNB+jFX4l=!Edo;dwi`LD zIJ0KwrRzt;jbmdd_UI8rxhi&3})y`?YX2|q^Dk$Y5Fw=2&xyhu%uA?kRi`0XH>K4G1Ki|RY+ zwAd5bJ{#zMyUnV{smZ$~fpL!jtGZ!WPNOr&twfE<8r`?dlQA+IGAKV7hL(ixesRz_ z`kWaR6Xo`_1SgvBrt{*V34i1It_YcJxZ-&tjg#@khBlwuWx~Zd0(-Xjo!0yp8DPQ6 zDLm8`=euVt$@Ys6V|c3nX9Kq?w!`ut-J;J`?)!f(05#q~rC1%|nFY{nr7ujK03D+8 z-=9gu#moh^U0<6f>z(9~;B6+%WJ2L9EN}XUF8MQPh?wdpVE&BV2YdIGg~0j;C2(3n zSV|KSR~JkYvzV!Yikb6*9sD7)|N6^UYYT>-u^>9_6_~V;TxjcI-c4eTp_Ug2S^D^j z=-0@`2IsVGcpGcxy8>tKL0bcYq{{Gkrm-0DumOPiXWDGtfn9gV1#DZLI7q}E&7P4k zRS9F_e(}Gc{V=xI9MOUIs6=5z-B!CgP)p?7Go`R7A>n?=>4s5$z|ZS8EH|5K_#d&n z@C@U-A#0+Pat3n+qL5F9DuZ4CTTEZMK6-kLI^-a+)9m%W?>owX?Hp1jxQW0(_UVvG zFVY1dQetXQ5N!JdX-NnJ`x1hi-$JtDi_wls&eM8neRlos6KFrk2i?sg#^iCWXXc*U zd<^U+gj$%$4X3>QF)_gpcN#nkoJL|&QsYoxFHostK=v6#1{K=@_6QpI6`?=+Q`Drc zr;9ENK{tZXu-QTRLVadaCASA5b&}|rxnLnkl@UM!QJn(j6T}t)kW9Wa#liIzIE<7;q-0+w5Be#a-p86%6F&>C97 zL1n0e0eL^BG~+xgZtesB#YTUJFWn#a{SfAYnjU+v|r;}p>Wuk zBJ0Eko|7W<_<*2Eezjzah7kvFRjjBV6?ed(H8NtoV|9TWz&OIq2@4(F;IB`2TD}l| zI%Y2eGnF_%He2q_Vul?`BPWq2@XmEUR#t|8v$FCK}y7Tfggo|IRmod`r>GPx30*$`XJp- zYR!;%tFA>AHpk$S^G)Qg9P65ET5XNp5Tp^6?s?$R7Xxj*@e+<4eY0Cd8kfw&o4ryx zPDJ(Hpm=qR-Jkzj;TM`~(m{?#(n{O?GZ?C}4FAk=b$;_Z5~%+B>;OfQm`KPQFUEF; z2MLUinI|moiqt4y=QIjG!pwd@on2A8NkK@bWv2DO{^?!aam}xlezG4hpc~kFFM7P| zw1wo!gS2A%w(Csnhze#8K5Z?}Pjl9?6NC$lqzg6wrQ8oU|X z*9)cFPQ`icIuh3(dqB!KB9%o%)@ZxOV}q{^6jq12n-AoBb}@u*4yBh@(2DMwN%AT< zd@6HiS)yL27%SC7j!{pIjQ!H+;_gWq(ML}wDFiuDI;xa}Qun1=*fdVoT3l)J2;R-r z<}LX75_8FT>C55Nhz%t^gPB3rJ_FgKr+qibTK}l=8?nD?W?-`w;`o7#ZYRfS;3L{aLE{2tK zUb7}&z$m?zo*O*J+q8PU5zF)zcGUWNl7!hJ+@N~EMEE7@<3uEv;BY%o+npF@7t?wM zs4PgC8V@=Gvzx8@X|#+wWLx9wU!N%jy9$mYwUZy)*FB3b!l=tjDJsh;q+w+B?Dgti zhdtaLPc1)PrfYa#-gc&R^+)&Ox+;rx6olXi&U-sIDDO4pPpf|HD$w3r_u+~2)8XLt z6bE0NF6tinanoy2rnF!3^=Y<1H-%?5-?p@r;PAj4Cxiv{y{!J+_pK9E0VljEKR#_G zD%j3_=;7N87#Wi}n(ODM2PdtR8ACN_uGS0PjVe+de=f29^VzUPzJnE)lMcU9KKH%6 zP^w^8B`m1%%Z{9Exa}&oV@FzA8ikh8GbM@w>*r2);rqta0KaU&R=+Ls>T34i*25?W zu{cb0bAJRv(x3@|v_f&_0~nS5&VT0U11l$;z(2%6NU|uv-yCW$=Dyhkw&_903I@~S z)X8>+s99`h0QdeXKKkg#UBLHr?w!vR!K@Il`x>|^C9pZy#II7Pm9<;LPulA9G;QB#LjY>Nm zht>no5MUYfLte+DsM25!OD6@Ar+9EHA)m&$aiM2Lt(V|t#IqJeOOPO!=2F+*!ckQZ zoy@0w++pfOWd*}O@Ou0f&JvO%+)ao}9??UfF!=&2*l}l@zdWFjDC5W>CUDlA);4Gh zh^z?z#Y<@>+${gTF$x1|doW|sAMpheIt7YQiIFbj(9yQG6lI1SF8CoABi6 zhLPK|8x3!(f9kn`jBcl0&t^5iJ-A`VMVfD7+Y6y9Y8_J6ph32UXr1)%pgfM#z9%%5 zy}gzVSRl~w!v$})G43otFYVFued#ve+!YIlcs* z*0V+Cqj@ zK!qcLwN6wmc{Z!!Sih3Qb`c7!FC_dMWyi%uqP-?uODIL)<4?FUpe{XsmKKen@`Vc? zFE<@WlPgNm#&ZQGhZKY|V*{>q4$5Zekj^Z8e}8Gd^Wa6wKlmvb(CPcXGBf?;yf$Qd zL8^3dzD-)1THz{(SXU-P?8)OSx<)tlrvFTSxO?|9?~}8o9$KvG{m{EmU66NV8*KPp_lBc7 zyz-!^X!eG?VvVt)eDMjU25C46R9Y6SWlQdBfB#hJ`1toL>CY1>mP_2M>MuRg+;6}Z z(4-cp=~Bz$o+A1n)XV&ST`W9!tdgF^v^ds`tM~9xY~S=@r4xsxkze<;v3`UTc&5W zv#1CbSYP;jm^I&M{B!(t<9-7{8d||t#t*tzr|vz|aT?z0^l3*+uh9sG$cyBx6<+WvOuE8$(L{aecHosbYK!8KkNTQs z*@AyGr?li~Dguw4cN=ZHa4Upyf6$3$wZWYDxkKAL!oq9rFy+O~F5IKs%|YE8$~>jT z7IwCyL{fWfMu?tLEvPwc2@xTC`W>Q7ng<$XUWDD+p2pkj=qqygbAxd3 z#?d*`5XE6W`=-GrgN*ZWCvJBso<~NqX1k%XN0?^%vF*iLFVkdLA8gy(y!`ZLnV7+x zR)vx1%JWG7m>c<@r=--hKh`N)zDkg=)7)U5(CFwZ7|ca6%JTP)M%%;v`>0(M^z0kc z0zcV$J)bj;*`sCo@SKioY~Oy`FM&PFD@^cs@mPxsD=kd&j@_5R`uAv)jGp!%-oLUB z+xTZCM2)M$Kb`DY(n##tJVtL#Ddfs>KQCr3apR1Ra)Gs6^5D6~-8Pjh2fey1Lq`7!3+T!nh%_fBa+8@wAr4W;?o(^ zV$w=mjaT@-oA1E7#mo!Bg@K%V9Pe&>@Fi2>BujcLV?;iW;@U`qwfN|xm=WrDx0Mg? zV~>yi_@NfVZhAA6DXgB; ztK{&`YX{92rtS-d&U+94E$aSt+)#}Q~eqCj)!!5w&B;LPH!+5CjJh5)E+?x*^!gwJ%YoP(7~p`bGv*h@1e-h#pBFV`Q^@HSkU` z0|dnlz9bs>OMu5{OZWeo5Q0aF^2OHd>}-5x*w9Rc((~yH$^gHoK3^aF`0+4=j&Z^m zGr}LeXDIvC!8lrKd%}pAT;sQF!=J zmMvoVr|$&PX_JzYa2DX?CK^dt$i5o;>Gtch?&QjI*KbwP1JfcB64zn=9$cTcw*cu@ zC2Dfg^^t}<-sN01HR8ai4wsaHS}V2v;0_LA1<#9MnHx85T$$XsdTaa=!u}rr815eD zr_H)^=ME-x#MK=9PjGq7EiBp%r{HWlHn*%~WrJxOdU9B_Ac76oh+>UUaRGQEvx_jVmE24i_6VDK=RAN!9^zf7d3q6sQ!`A%Z&Nw1&w2L~VcyIkSV*(z+EhdwzpUUDWTw zyU#ZUq7Y(%d|jBR=AqFhN-Der)FRv~aDNMM&1!%?7GAgFwgnG~pbM0`^Lv@awEAHR z1GndN=;&cbpHlqQc;vzJ=x*Nbg;5K^n7={51ML%>WQATqN}+%`{)?Pf4JLEl2aIjFEvUT(yN<39lLgY1spR<^dt*gV?TIqyZD(EWhQkU%H39C)FktDJ1;Xast4%e5PJ?Od7jy| z#GZd`b!bf=btWpS3C-E^Cr@^RaP3+-Lig*}FSri_VMnHK*4G=W69Y#XD`#LLh(v-0 z7)jwdAdII?DIiWbHa`C9psT2a#QK{}SM9h6B82CPJNCF3nA5||%$I+j6-<_um32IE zK2F|$#GZp@LZC%3tw9!e{PevyU+vHlX)Y2U2Vh1hQc(2B{nMC0Ata+ikPBrzSi^5z zsp9v&ZEqjXy5~w|i_tUDfx*#p?S?be<4Ix{K#cdyn%>MTTFhL^(d*A5S?svk*qrhg zYv9f@`qTIEVzS|mMl=jMFa`t_#JlVrUpa5el$z)1UC{P^ z?cBojV7qcvQ0Tt;l4m9VE{6p!+1riSWHPZNi}>Ez_c)2dlX+Nr8=di%#@`YO>srEH zsa>cX>Edh}yW5$o)Ju+umRqSeuM?)+?G)4Z;N)Ss+gqbOl4eIU7(#T@DtH=w@*XYl zoP2Q6D*X`uQzNY!g9D6uXV$Gaw3k}d$$tpiIqWoh0i zrTkuSH@Ch%mRmNKGEW<4UB9xerk_7RTDRNNNQd3sde(~O*rhB>+ecJycBk}-&q%2-w{P&t!vwC^0wcxzWkpnDh4BU>9%>BU(jY_m}raV=-SNR2%Xf&_Qwz@bo zvD8p$zvJ)EWgp5F{K}Z^QlP0br9b+-&Yq2HPmHXL32*2g`WS`BtgD9VlHU`L+zT|$ zua}M~rJWiS9<9FJW;1Y>C2>?_%zShD5j7R3j-|69wrXI)WQ|x`Ns@oE*@Uua@ zASkH4R>Vln1ke!&6W@z%BCKiu(J4Ejr)?`E zJCvV5FKP%q6AAk8C|03|@=f;sysHiwd^6JPe(&Q} zY!&`@OFL>JK{)@=sG|{|aJ-JE)iA$8f9S++U1qQN#v8{aGOVUv7)d6{Inb5mM|IEI za`i5rE-6@epyhS<>%j;YgGB$mS>IoDnh0Lb-C_}~IpEK5_l~yyLOgTafjXJ`AfG@! z|L7YNtevNnUXSsNY+!gH!8TE)dy>n2Hp+8!+vG$ab*tieYnB5{!9I`J{ewb2vr2PZ z?YPH!*Cz7b=PcI`rHY46#|RV~Y0a)zsA^GYEA9Me_oB>d((s(l*4VcaDt(Naf?GDx zY^!W575J+eLvt%e+si+$JkwofOpE5I&X#X4WL_A!+==XZJo3Y#;l5aqWZFS*LI3!& zoqB2SF0a2=ao#XD*BI7z$WR3)VE;bXbZV$}sN@FxK#6%f@*PPNA$r z_%aysZ8bIBnf{+e2&Oot99scxpoQJ)J4^{b-^a{vP@b6cn~GO1!r}Haq>4Y@u|9I* zf&CmY_i`In%iD4pPP6zLbG|WRl!kaRI4`s@#zbj9#OOH~x}Pot?;_*`cMzV(gvroR zd*w>ih3Pl&YkM_=1bUg3f?B=3^71ke!;p>CMQ#gz__6F?&=c#(@;vG3#+;rw2`l)` z?~uC}g#k7{%P>CZGsX|1qN3{xdk=vng>BaUW#lHJ=3cG{eS!Ogh%(rVW!PW-d8c+I z6g~f#$5Q|+UXwjhxcp8ep5acyniwJ%4_NAxJ)d)zrH(XL*{>tS5Q2`0%hv4Nld`he z>Y>$W(9~X49fxHx%y&N{qRtjz81d!9&qyF%)X*Xk*FcO(kOKudh#6lQ&WR^VQpzkNHjeZ2Im6ZE^!;8UCOO2B?+86t`? zwa>qN#MG3yq7lm!6y(r@tUEk=7SsUxV6{`H9&XaKJbLsrZ9KwR;*iV)X*}4%SyfdZ zs-iiT=W&NEoK>_6cHvmdzVG*@O4UUDCl+&p)d_d^Vp`5^h{o0tJ{ih#Rd-hweT<{* zS??4iH@5^&4v5$6x4=l9P*}iI!IhX{Aa6_M=g^|2XJ7vj5>_G%hPg12|G={C9IaXE zzoer@N@^(e(JkQZy0m{=Re_0xg>(P$luuS$ndp?GHY8I;BHh3bQ#a&eAli3a`Kdgr za)LENQ;QtGoda4e;rCvZ#N^h@+&W zgiL0+;nHGNOkRCPfdC90dnc;T0uLP%xiC(ft#Gvaz4OF~Sk3kv0a*QWK~TJqmuO^zE&x zfBV}dLmkDtMR*g;!Q3urq*aNM^InLyDJIj5kQ-Od$uD#&-HhFh9R; zm)ZM-yIHKv*PgMm()oscb45ne^0-4o$|HqTLM)qG+ETvZ;6qvW}{K=^U!3X zjk{(#e_`-5{|l<;V$5B;e9T36`3r}rHQj7Ia`>L)LWpbd&z$j7RG;TGg{vueCR$Uf zt2WRk-G8++H_b5FQmmWT2s~2G@G1u3->P(E`d?_=p=78J z;8E|ka~e}@j~=G9+RbuxSZPT6?~3NNc7vux2ep&6k;a9r)_eyQ)9obwj&3UsrXC$I zjWf;CicIKeojFY5PMu@;;*-@Y!wqG0)Iu&^g2Ed%eBQ_E^OUdtOj@3;1c&>rf$+*V~MUS~i1-&iglp13HMKb$gJ=Mzia>1+_W^N0pZZJ>@xota(3;7~NHeE9THovE2- zpW^i)$3?;m^n+capFW_@rJSbvJtvIIIk4Msa{j!llxQNSp836KeYWD;k(KNE!l#X; zUB3Fp$~szcWKW-~e|z#<-{vSe>h_lgdmQ{?N;T)hqdAYUP6K(&_I|uU@pxmWskU=>gp1zE8IiH6 z?Wd$}#dr$eI@5}0;JK~TC${MJb)9~}lh+3C$I7x;4!^nMeoHT0F7}n`HP)`T&QHEX z_WA4SBvFN5&{Upqr>=M9ACKs(sw|=v=1E=uu<7!h zD*>qu&besUkSzqNFd@-lpSBpu&)N0@7W$dk717+g)+`)Drl6$8D9)(tjEqLcU8p%A zz__!^YU0Ba3H&KR91>FEhGw$Sr<*jLgkLd?x6o$6Ym6`gP-acZCr$T2YPkwil~KRvKHwY>>-FsUmZJfO7d z{RU+?emf3$`1ke0`ft~r>j>IF^sMj1gKMA^l}%j|)=fu2H~;;8%bOy|s26=sUYz%y zL|sgRzwnvJ^KBy%Y$PRR()<-4%!^D1b04@5JK!^7(F^d?BJF0uWHOF9%;WCC=%kNo zOT&mX3=_$G2W2>ls}m)sbw!;;ta^9aH+}nS{pRE$)At5xW+)OUyXnp#AQRae=6uc2 zN0 z9mLZXa4E9Fss{tYo*ajNAIUR9VX4vUO-df}CEl9g-V87J?Tl-mTp%;PIvOYBMnLS& z4G|*+K_P6V9fL}<+s7Ov*dpNW?M7DD!1NhN|G|Bnwa8O7_FuzIBqS_WQRX=z^-~w}VwLB3Vj*dt3FwosbKReys$1R38n1SK(=p z!r`0a?o&Nz2{;h1>>+M$dYT;=RRh}7jawq_xH7-?JU=TI`2s~OgbiS!&^Xxx*v zM^cpDSKUPO^Vr!yn2qFFps4CVfI1=PmN~MPNCI8eM9c#(vxQ#ayMK58w1I78kP#)> zf}%;|LCo!@_}r>0B@)TFt7<|J>lo^i)62iUkib>cHK@apEioisC~hz14MHtJCMQ+! zaq+3CN1z~3yfD?ro$9W@i$F$;H#r(9Td1lGfln?}9pt`t-jKr;n@nY3>+OV~M3);$+#tUhX+dmh$FFb+BpW1ocQO$?+h z{FlEiRu@d7RyC*5f$21MFQzSWH{h#q=(YPqhk)g-J6hIRR6TGd)+HTijv2lVeMXP} zbrk+O1%T}(dcQCI9K(xi-gHPL9_to>l6{VYai$Xww$p7Z!|Sj{(yhy zmC$rQ-%(=l&G4#iP^I>}y7%2-&M|yU;WNiOk0jpz5v;E^8^NXYNPthp#XOa!z@07e zvql@|;|Z$Ho&1&c;djO3Pv2%_uBSOo>BZtrImETjI4qn~CcMCe#kjnXZ}7+c2XZbi z*?@^W6URVjY}N>_YjQ-`#2jmaE`jN8$^*&F4j9Agx+9`g3E z@`|W&59SR2aA@i`5t%!)aav;gt(dUn?Xhxg5y>K~L-}k2j}qU6oi`SA%vXvv{gcCc zl|vVEu@mbG{>jaA2}eJ2)c<<>`c}$fp>Zum7xA5q$LK$O7La~zCr+=#A)MZ1DWiSx z7*(*YrIQ^sQ@y6gUYX6;+GFg3Ez&MqF$uXwXAZMJPtIUd+INMWE{rnnP+fwUKrivSN`n$}(! z84BU-@O({;sJnB~%(CnIJPsFGICpgpM4c~}_{Lw=`g5Jm{iChREVP%7D>7J5YU?$b zzcn^2^L<*nwMFdiqF1pL7eh)X2ONx$-4ymXk9L{{kzT!a3=1( zQs+*F_{1d3bl0W`nPYl9Sn*ZJ}O%7A(OYgYP;@eOxSR6&+XkWEZCc!*uJT{jZKM< ze3!6$#4?y0)Aula@0R8y!H7XdFIlejJPOq3ckHZd4h=u;d{Wctn$$$We86{|( zwp?_erR>D{lMTX?D%>qe!;X=mmsLeoU`tDj8NA5>Zy?ZmH%<{al4?nForXC&2}VBF z`;pLr;1%nxlp<+SeuG&1Vq@w8@I&0-0jwvYMI%~2NZpXHfzFOx7YeL|(AVfzU$?f} z!#j^mY#Nig}f%*>K*ACktQlXW=h1>9K?uy~Th zCr$QlwCaTd!iNxz;P?WC3-AS|yhDHHo-gJdwYJ{6EnrEY>E<@1VWF37Ko^a=hX_V6 znX*Vz-@yS5EVEexPCj z(uad-@>Wn4&ikh1r=UzhK>%mC6L#5QcK`W9Lv;Y{4VQ|aES{GNhC-I!?@bB52g3pE z>Lj29hXXN3h49j@V&Ey3gQJs^9qda;mEpTD8kxdbS+97mXc5zTF90qZ>wSL8v3~4VVpjpl*pr8YYIxC_*hgTlb<7 zA*^Ih$rBwBj{jT;uIkh!3~yV!GbpmzymP|Bvi%0Wq$RZ>t= z()D4Z$w(j32(2OTK`B;A?0V+gw}%s?Y}$x77H~aO)MJ=23o9$%BeGwz+u)l|mG(oT zYa_N%heL~(`Pyv<%na$U4>S?$I}iyG?8Og64)tHmy?ZUF6)%FAu2i$p(*5NoOW%>u zp1u2_0z*G0k@gR2Izm1}xhbn(%Ax97aS!ODoJ*grnRE*7KK=n9o8)>?>6z)RDt^+a z`6S{qkx1t9=^3{=3v^9(08#fLxfK3UW{Ys#Skh~Mb+pYyn9UBFJ>oKHo=u=z%a%KZuZ&Z3U27mn-nARNJ+#gVD>bFkcS<^1ve~qGXgXJpkY0GJKrr?RqEQ4uHX{$BD z(}&n5wt7B)9H-{|&dH2ZNZ(tqiuUB+qk`Kc&Wx{ek2@rnnEhcK&upRlImq+l6J4}M z()^Z>vb8)~&N3{#RBeY`mb8^vo>9FKpw6^b?arFayLBdxQ{!5ZZj_^f_&J&aw~;OO z@}`CsrZjf~A0F0V49|XiYyV$w@ycs`ZVRPBbR$o6+5A;cS={28qw%Fvq7P`d&t>a+ zt8(!C9>>cv+VkvR4*Ta9nglprln>i^=*9aZua51A+20-d>8qw<0jp{{>rfucK%6Fn zcb{1fUE3bz==*!#=th0b+3n;KC+)HE4+SH(pKWOo7dh= zM$_B2cMkG7OPwA5;nZ|&N8G;0tbV)oeC?^~MOjvL)Ng)2Jvj2R-|AJTiAb?ks23$d zs^!j3n9+r$U)^CxaZZA*>>1Z~vxh)tnaT8EPN(6`$;LR{ciN)S z#xfS-CBk+AzhyGtIlfcn5pg$E*P*^FFyWpu@BC4RdzkNYhEqUe$*KlRhXzmUT4d*M z58kc~H$=h%V$z1&DKlKCedLwTeOi(){rQ*f_8V6UNz<1O(?-U}lw)-oziZyIPlE6s zbZHXEv{7ZA1$cbJ_#8=>AY&&SL}~^G%>S3r@TBh6QJ-N_#ld3t`i@92vG6hzKnU!k zvbxmvp6~DTo??HZWS32ZNC}N6G4}uU%VYQ`*F364L`)AOZDM3xb?qtbv0ZT5M~hDk zXyHZp65$R@=%!&w#l$5Yh(Chd<{~_QFM@kJyL^d=gdmzt*kz?h0J$$aTlT=E!5%W8 zgyS`KTp*^*5-K|8O{-jsel^D`Y{fZbH=W?2rvW-+bMX_f~q%4udr<%5oQ}9yk$bLhvyWm zajDD_6g@uEql&F)-fotFQfFn#LG=_k<|%Do=B0byd8^=<_%E9k0lHDzn9l3~q=&ll zM319bniuLJivMz5fywn0li~qTBuPm^sRInaRrY~ z5IcWTc#^np=-i=-AodhJZk7tPmH{*VLqa(2F^Ad>G8R=(;FCLn4B8j>ZG}qgOunb| zj8B7fqKdy9NI6-LAHSu#RdH@W8|WcP;~_~4ViVl5Aj9-o0@ z@@p%fK`k6%L|lfeLI4MEYP(O^?ahhzmq|GoYz^XPtjKS7<*>)m_QT=A|omp*z9R)Ytz9>hm3}iR!1anP^M_%6#(aw zbKKJBYmrIoRQEW*hed<51ip zO#ZkO-1sm5nSX%DT;g&4DEDl?5+ug8qa?l#uMinbPk32lLQBGdiKaU>?xu6S~Bn$dRO4l4pX~*1dX=m^fl36EAhZuAnFi zkW4OnfkpeVVY3xw77j_THBM-1y=q539M7ACx4++YSIdq%x zPcR}qK=uT>oRJX%GM7g^?dBLn%d#KYr;g`_AH=JC;Vo6;77Wf+GQcAxZ7*SCe*EZS z)`0~;GWq^2&TLFPZL7fh6-@&myDk8~e)ALRfN>LL9GC%(3d+mh!2sj>?BDb~Obu^F zc%*G9uqFtq4qL}x5E<5gXo1sI3+_h#;wI^tJP_qb3JN@T@ZWuOYuEtsBJS|{+ST%f z8=P9KH=9n`=aE_$Af6#OcaWvL!RtUFBM)|7G(63tOG@^#Z_DVLAk?bTy>3bZg(Ozf8jtYA{lp{hU5Lc+sD^OGF{|5>>=yas9c6G5+7r z)0ox09}3%Pqj5g8+}@ojXUtwt#*{fM?7f(zzgF;_r|tu7PG-zf7g#o0Im^|x(7a?c)4SN<8;EvPi@%w8sCz*5;2`tWebQRX(MQz8n7*O&ZZ zx61%Y{cM_?qr=8O|BfZz?C3N~4-&L3W3&yYzu>lTM^~;t;_;#T49}GKOGMN$`JG@h z+~QR>JAPvR(%@VRr=taLlmCMa{yy9@(#jr&!WWMo_!X(pV(;=&Q2V1=POssC5Aqg6 ziJEuS56T|88?r{%{hWdNg)>{3lM2Gwrf>;24ar=02LDBtlS&@uNojm@=9iOo5C#X2az zPL|c&q8@#_;td=BzkjH?{ug#T`G^C{A|qx6ps zd64wm0U85OqjjE{_fF@t%1fV)S!L5+xy9J__UrC*Z7j5L@5Rr}yfk9_W7fWr&qa|f z+JjZ)42$ZcbiS_610oMxBshQ0mvleUR4xo<{#)E=7xkHvwz%%-sR+67X%pPu#=ww= zhOfeHSI{f`3;f+p}d$ymh4Ih4K?6*2(EY?MRc=-L9QkQ1Q zTmR1mFoGMF9T!2C@b*h+?q#8$1q;j=SR&3_lqc94St2wapUQkTN*xsVCH6Vo|0mBI z*R_7j(3RvY)w%KR=5NR)*S|f?s(3dG1mRPZK{?x&#TG*CxDskWdoQXgLuCEar%y4X zMC=8LP6cNNQ!-?ZK?+7f6M!*A^D3U?Q>E2O+3LG1QV=gD<3kmT%*W}LyLx+>i$CHp z2O9Ta?^mb|RM(c<$>6td`hv40A^cvke(GVLUA+EU1M$zcM4aVNM6Tf8Mf9900U@Z; zTGs!Y*TcjFQyz-XU73d{CD(Hn_O?hAK6`N_;1|2!?rLA1yg4}zM)yvw&1TLr??-Z+cre<0{8&<=kBBDfj^v`)ZNiaS zE8U{WLKD4BmCP_DZTz*26fR69zVA;ik%WN)U`^KQK{#|~gEB2hSx0lnE!nOvbj30<*{^l#O5+7X`YS zntmIoD5O0JhcjuoFdO7X!z>g7aV{o8cf?6ULFPi>3kj&IrUJv0@tw8BF)RjKV)!Fq zQv^|+MSb68Idb1>xR=6b(<$D1wL zNV)*vWCYDcA@4Q);WmjC0|M*TZ-3{%0}ptr#W5RHuA=ht0O1vYWnj{p-K`rCupiUC zmR;315Ch{5A%@K3PsTkKBmo0Nt1(pBD5GDL3?J(ML427Vp78IhG1jMm9)tKsAV_VG z>qiALW3}TVa!B$L)xx3x&=f`QyV7~!7is7jaHhLqYTN}zBI;szxC&gQ+aL(!6o4pX z5tQ+=6$8@=RK%d2bRDaI7s2gZCXM?QXWzDhZ~d0^XL2YjAb!r_L!oLn%Up%XLa zJAAL|)=OjDOiYMhe(~u1^r;QFEx|Q$r+{P)8mHlhC*X&XSXmDko_pYan5%!XDI+KR z_03XeVB?8uL{OIvwJnK}dlR8U>v{#vT?&27yWPe!{9@IG@3x>O{~RxSU~zbo7ED%oZxe~vW#&_AB%1-cGED$@b%B(Z0Ao3Tivct+2(;2t$4)^HPfxU+pzi{_ zK2hWM#9JQ}Kg=2SVEjO!HIyZ0n{0C4#^UAlM41FyE}3*(gliT?$qkbW#Jvg6^Hlya zv=F4B-M=z>0Ph+E)FWVGWWsu#EJI>{4%S^U2#t6>4geel`cHOAm;RDakSGu|NA9s} z1z}b~Ym6Qn#v8~rBFSzTBLL=1BFYu88fEp-B{`?+n4*xy3jp{R%Z4yh-QJ8!Z@oDO zH8W|K@nCvhmrs!qE)gQg7p~o=34ahQOiv?PJkwT5hU--|1g5>rAtN!m7NFc&%Lwum}A;YKmvWRz+?)CdAbl z))|S96ch)(4mS?T@*?qPX#bN{b*js)=sM!F;AWXLLv;n@c& zw%>;%q)c?y3r>itEls_4;^)0Otm69Y03W>^eW8BG9;vXA?-O*r>RyaXqDnub+!8NX zZEK7;ciixC4eo0pQ-^)Acl@RUhi@BU}~ zf=YHn!_gux9)1nCuq}#50%^)5Z**4e2xFDI8oyz%j6&vCY4Sa$v!!8O;S95DTXZ27-4S#s=hEF&j*jd7&XKM_e=_wv)wbhB)_rQ#se&~H0e{-6wlWYyQ zkizTGn!_(-`l73Cnz(eoF6Ieb?JF1dd$)Avlu&SjY^fyEAis2flqJ8dX4+-_;IF;u zS$z*rSucdU_xQgK=TTVqy0T5?J$#k3{v23(#_Ol7Ss`o~ofpo`D#kxN<vnsyDmKSlI@3{vPV{+Pa^db1JQX68X+c+tfz&q zxvAGD5s~=*~Qu^k*3wQhl^^5s27{XEt#Vx$8=-%rNt>cIwcX~meAZ+ z6)s%CtFDB)(A3u*lZ^Gg+H`ckH;eJwwHo=2h3U_S{@Ii! zuZB!1>jGublry4X?XS<@^~iDO2hJxpG>4v;cJ6mQ>RfA@dhols=F8kliAQR*>Xhn+ z{NYpAT?7RFUhpVaEM57R_%m=c75u@;u6|02 zY{DVzM>QA(H5YRmqwfbErU?)7Nu%zRI{An$`f$W9lYR7ibgb)z{r%tQggb_9t#vW& z;4vS6xDZ5RUK7OGQO^?l;dLE%RZ!34Q)OIy;rpJLQ(I|xF|PeoE?jsa-tvoqa^s=6 zmsIC9jmqjL2<4fFXKJ(2|`xFJA15f1n?~MNTU=}P+=xWnjv|lVPPS~?(rKkVx zWungRb~-e2!ZTYYSliST`#r_%RsK+Zddw%0r*ytTi$e?B$aEz2K>?#%cqW`LWN7h!nBQM$$Ayw~7ELTv zq^~fAh7U#T&grfm3Li)i$Imcqua8Txv9ZA@DjPm<@7^7~wBKeOHzu|Sb|;fvGI>R}Nt$e2$IbK|?zfSe zxpS;8h#DYAF&N&U%v{6Z1iz7pDlIJ!!;%$yG;Q7ZTz}fwO; YY=>%#;t&UkDy^F zcuAcIe?XUe$`*5nd%eMe!X#g=^CMVmNRWVSz`kG2LMSgFWOGSMnxFRq3z3ZRND&7D z&dTDIHF+|?M~*4%bW5R4z=(z*GbrXrQU)G!For3xakV?1e1e4S!gQ(*A`w=aARIA- zm|m^lv&yfcGWgr`&jV1$>HNW+1Ko_2a)7^x+~IF>r{lZlLttBBsRNP(4II#?F7$MS zJO8b}R1~%rndcH*p45O=4tT9UK}}-k&M&BXVAM$bGZ4Uzsgwu|DbUJeE=Se^;4N~- zmapBot!BHt?!?G~AHf~b3;9mY0(b_*93EKiZmhd5=ornV} z2!s3=;1v7^)}lm1{^nOyoOEdb^YFaUy|j|^9N>Vyo*rXlI#CDUPb%>_1a!bqvSxa9 z4u|f@g=)+knIIBAgVVnJIA@~&VMr(@dkZUiYwg3I1F-uCn61=*(G_DqBz4%~$^YJJ zyL`*cSx7)&zS*`M9Yma9sT^v&JKk>Rj8n zEm|n+qpLgyhNtmfEUo^N5t(opJ)9gN^Hb*KnmNhN- zuNvfkqSIaH zs~z&Y&yLe}GBx}@w)(bJoY?x{UzBv^09WWCFxa_RspDuq^HMG(syD0? zxch0h_lo_Uz+0!SOFyOGGZmUnHDa1GH)3L{n4OD0#Z+yq)x`hof13L4aID`pZXr}; zR!WL&rEEf0HYtkGlI*B#Ss7(!m#u*`j53l2sVFOvy+RVQQ)IuN>-W6xalHNcJU{Ec zzxQ>W=Vxk4tw>4Qnr`KLrO7uNy+L-?Qh{&t9{CFbxBOJscj03UwDW1GP zbYcC92aR^S*Ok99PIrg6p6%%}ct2&SN~`*Ffgw>Pw^msw`u#8CFZY}DXI^Fr|l*4sgHfpoRppLs3i-;Gi0ec^s?EWiJ6drcxeBXxr2G=oIg z_LG-8C4;8+Ur=Ls?``p@)ob~UJ)b=N`akA5e@>+~rnpgWxV7=~-0x7v=Hmj!naSVu z)3pP~xY#KtU966?H1WlF#&(Y%7cKJ=J!Hf3@YUmM2k17;d<`jPx$9oEURfdvw9%y| zxyfY8pSv!t&FuCGiM8r%dojP@*9lJ}`eF`wy+x^0~SZZsDIAh*3 z@wZ3$a`j%uGmKe@veRS0?^)N=4MHEbe^3S8g*sawGeuO%tykDL{iG;ya!t*o-L=q= z)-fQc;cV1Vb@wsX!B75~C9BjE-L~4A&(|i8I!5!oHa0Z{9@6!2*iS28SgksFJTY+QhRD{fo5ZTHVdF~4li3=)yU5cEJPCcs43(Z{SF3u7K6 zc5a9b$4+}i*5M?hzbRirN=j|ZAd(JeE%V83 z3ArfXc=q$oUzrNR-J;3e!*D4t+m5ukk;+Swq=k8lXfsG?5~9brx!0*1`~%acF1cd} z6Q|8c<9)JdgEE7^2PsWIas58e+5fKI`5dH#c%t}#cfj)B3@2+j|K&^>e20GcqM4~`NuVBRnyzej+E z{;#ta`GSfkH0uVldX@(V>SY7eQ+TbVqFZd*<^E2E9NA6%cHt$}hJ)v*Ig6iNisLgj(yXrN?N<5Cdm#UgxLJ`(Zx55m zSSJ6%`Q8xv*>`3Qx0seGgDql2zlB{Z4r|~DwtIHTwtswooL5nz$xHo-K(6~oHa5lk zp4@VZp>vJ3BALRgg{x48W>H(QQrnS!kRg6qE5&lrGE_0(*kjd+)6JqLwo7S`0(cH< zOed=yyW+yuo5NF_5;e#qnn82_h}iOl^lBk!rRn4kt_W{zy~U;b@?FF{r2vU+GQK|$ zDf%ks%+Hb@v**nlRX+(c*q?RaESONq zkKDYLrIv3oQOF4X-J$b9$M(6uO=wWdV@0(W%9e>bv%dwsOs0JFq>2pRxkZ^qW{tNr zs!j+C^NO0)k9&)qsk9A989qhlo0&u%Y?+c*XGa+*m_)7h`mv*6|Lev8k7rb6Rb0hJ zDhGRFZCNF!qg@87*}@Vo7VkUGl$*u5Z6sLVQj^{{gQI3I7vBbfso}5EuFanWM-zA5 z3{^4D4D?cy*uW(-lA>}YGVparKxWSMTl`$?)*QwDSDoyi>rcE56zp!5?c?GJ*Zy8g z?>a!;dV&T%wT5W6r&>&%*(~z8gMv&>^D)c}>ugH2a*YLkpUJrv?z?&2AqV4rwfUTu z<)ZavHL`vzR5wMlUGH@(k9;*9)QeMz-tm*ayJkK6lhL{#^^%uebdsaCwhZ$A`XRYN zbH_hTxyRS|t~VugB&g^yMd@BS&mt9gPWew#oQavFAIpXe{}I*-zNEg}FvfBBrEH)vRq&rlSkFP3$*=~u& z+Uk^WxGERZNXw}ou3bLjeFa&&5tNSYs|P0_VjD|yr?He*P;ZEjk2f%Cm45iSzP@tQ zwX2CMr@}5q94?47PM=Qk+wVX4FxcwSJ~~xeZG4coOWHQBxr-00`EqNsglBYc zX#&eb-%7E>cI@FhH#g&VaXllz64qg)rNp=Iy*W+|WD(pYk-lWbMW}HYl9K?*`?rs& z$MFIha8R`XlTHj$C8sy-eEJb*1(<=mG3LjNV4>9yD_$(P1~-m$fz5U~C8RVuO+;v9 z;s}gMO45RBZuP?I?zW^;;s;PZg6x0~G&~|KmA_ebV6BA0jt`-lRNw~+v;zU(V2z2}7FF6T5PE2J;~)>AF-|NZbs;6;h&*6nWd)i_?%t3ODx&P;%3cT4^`RgB9(p<_ z?$gwKCh@dX^1JS1^#89`93VALHRC+9DkR}(;0un45w#42LL1EMsGr}BZ&P$MdWAQ_ zz`(%i>-A#wQnE{~9fC5J$QzgBjfYH${R}H4N%}>KfjN9a?i^irzx?Ve>1?DRv5x?B zVR8eI>#!i9%YT5K{x^iuF2Jhe;1|M4OuT5lG5|dwha%mkY5D>rCAL0X)#e$_WP# z&vTHPL&LpmII~qcR<-@FY)t!v zdv^_8f{<(X7x)grb?h@OybhcA9lu38z#%Z%y@3h6#3zpLs2%4igIltO@9i^h7oR`l zp{`z~{glCgZQ`KS-!gauWn|jtBnzUt{U+1(O^gKE4mkmdA(|TgguJn5^v!wUcaE1` zGjrUJD2stB{a0j@NUZZ2_v95$-EYo70&xFFfJ_kcU37(zzX8$!DD5GZcyo9uPqEXd!i*B(kC14Qd^Ha0#N|0~bGRY0Uo#%b$gP1EA#{UD9x&ca}_Q zq|QNc_1^{KToaWuXE@DuOnyQu($QgiG5}6gLqitmZ!t6;KYqOVyu%iCZZWY~-2-JU zXZNBOOeulxe|IYD+JTQ>W{n+cCn9wNxzj|S#bT+<7yd5 z)Oz_b;`C-TRnr*ZUpy?I%Bkqt|u&uxbNS=;VpVzWddRb)?>% z(w)ryJN?QcNiBk|K`;4=*7GnyAj z`wa!I#;~a0+87zE&K-DeFWoo&{aJ21ZyGCIXZ798MU6W(jW^2ue%q*;apkc+I|6h5 zmWnmgpP5ePdb4Zvhlzn!+Bm<{Esh?_re3jM+iryU^=aRZZ2Gx&&0$jIv=W_owq`hG z*tXQOJH3ijnt#(ZmDxY!Z)sc9E`Jxa<{;pev(M;5MvTzWmc@O0weA_Q3$_Y3>=E#y ze_wavX0CCCG|l85TEn>OyRKIUBz3-Y467Fo$O`Gn{?Wo6*`M%e;r3kX?pDgxSp7Fu z`g@j)-cNmO?D>?FLg(M8u68AL3%f>={@#6l`{)^XUue}DonoLU`Qd=19P}LZ_(Zca>zw1DC{IOaiEc%~9>btt-BtPz` z)K_Fxjvl3!QP&ol`#w5W8JioxihO`3D*K&5drEC@Q?tgzt*5?gC%xZuef-VlcS=1s zZV9t|+rr~@?9ilUbBS-;ciSbI1sCi8rvC5<=LkR}M%6FFXxf(#eMC7Y-(ca`n}O;uc5b|Jj}(u1~* z5ekLkY$DpMBPQ?6*KKOM+@-m6bGkVUWnvsHqz>9VPytp{V2ZGq=+}0Oq}hAcDVpw z4ar1<&taAX&w&4)CT>cw^0@c#Es-3;wDCp^Z%u<%wA!BP!kZa_aA!`IFI~b>4LA+K zLM^+dO1^D%T=#~x=1sai7JITIp|}m6_DgrI6#EdU2-NSzH?O^*Q=t-V$Y-*XpYICE zA5*I9q3Xb%4Cm^dO>HkllmPg{==|)njfnY2akldLZ{7siMbHO)9yj3^cNcw`z#E2^ zn{PCiA*W?XT7pI8*@j?}qYT0Un!aXPA+Eht3OgdjAwfivh#sg|4rzye!1G96>GFT@ zW%vS27lbj#B8FWO&Rcfk^asv~7Xh^V?h0O>&ASwAsq)&&8{s0ONR46mpU4+)7DD0P zyy;2k!P`%ywn%~5PEdVVm#2GNzQ8~hC97Rj(*FyvGnA0rh<8K)N@c#uCYm5iD=Yjs z#)?0?@DTCx@ujnfS!=%q5>WsFU8z>`sxN+T3Y=m zLZEy$UyVJED8{guu5)-KQ*w>zi$<*!ifhpy*N3WX?cF&T&c(%mXP{REQf-%(n9w~@G9PLy8IHI} zyqt=@iyh?x37Rp_rH@JjXBoKCaY@tK+B(F^9w2$(rv1-0*p78(z^3m{c5$G{}BmY)4e?I(xb3Ab#46{lo_ z41?fMg-?3}I#536S1)x4TITdJbPTEP&x_dT_=hMa#0ZrG2P zS*N31nZ9vkUQ%HavAidhR$;mLNtHIUz^H4}@%L4syJ#sH*8kC9+HvY!MKd*PC_{k3 zWs`Fs=(d0Aev;Dpqu9%4;~CMS{G+tRhz<&QtyNwdl*W@W-uyVFS=IAYZqRUX%>JZZ zl(&V=;satCYTCWRIKOVlSIRfNGrdV>!-X`uJ+^K%W|2-^5+|REYrU+|OcdLnWl4AW z*g^NB5^@@+m#F<6-rb_gX$&1VoSF0ZULC7d-h6nXFKC4Q-vOtA-J@z(0-co)cRqH1 zw2IK=S7}opO?l6c^y|mX7`>*;3r{xUny6;Uobx=e{cIHL%_jl+sZKf+Jl6MGQ~FCY zbJV)hH=Iz~!DZ}No&}iRFxxBkD|L%mx=GHL5MOI#b=5U3c@OGfUTfrqkQx&*t$2p3PzX#`;0NA}oo{j&b499ZuyB6*6D7 z->=`Fr52#|SjXY+R@z#Q$QaSi8#ym-GK*f7yH0-rlR5l*Wns-9)E0;=I8_yM=d3rWzdVpP}x>$#DTMEp-jv2}Px-oy^qh9-LKw zD#hftW2fh)4I$h#$(6!1-x*m_*?Nq?e96{U*?IcoWJ&SehGwaNkGtCaLa_mU6IV6ILDhCI#5x{4&6IZ)#E$Nz;)btK_1(6_wv7C zdgSvRITGpEr2w-fu>>t4mn0b5F_DsiCQtL8d`Xn;B#nLmn=iE<<>$Bu5Ja*g?>M5l zV^s))8bHZgw<HvoK)8525Fjy9J>AyZdXRwj$wh)ipJvjj%Z*0;nGpbehMvtAtlo zSp7GM5$2H7$`$A*EZAwVIm5FdzPbOoRwjuA#GeC`4N;T=+P(ap?mgHOL>N{>ghp^- z*!OLK`}-^mB~lQZ#2>`q0y6?mVnsFYIEd>$G*uHk!~nS5cZiXM=LELoX-@?VJ5KO> zptfi^08l!N1|L5#D%Hcsf7xI5S}ar$9Ho88*@a=*9F|MiJn9net@Fgs0nnq;EZH6L zo0{qjegZt2Qg_f>R9 zP4wX&ZO-3ehNNZ+fD>C_%)m@X#FV5*FarG?U2h3FPHs~q&P?@v!Jq6r z8=SlbF@Si!+{nZ0`|5TgrjX?Caw-DxBEc#|C5%}Uk;pP+GREV1MvU@j&vCXJq@I{i z8b~ew<}-T)&ue_?lxeFq+zoJ%{U&{X(2{mjzXxIVDU_K9)(N~L_$C*=oVmCXxkF0$ z(&W&OA6M`n4Z6REloE#fy8Q35t$wnwdc$)BE+_=y%_xW{5kn*7l02ru3th!_h+_nV zfbnp>Ii7oP=*W|*AXW!}3+JbflD7$u;-?9Bf^30&*xfEbf#IEghh!+hfU>f9NzDe{ zCWlnNAng&t?_r+&`r6|%Nxvc1WD@JSB#W{SGPxq>hIrAC3i&O-;ZSFJ2Sn6Fj0trt z-lQ3*A3A5(fHL)y7Pg_m62`HE!}zIW48zn!vNDNN=Ty`brVbnY`rsSLyo4B{uKR0# z`*MFZVY`Tc+1>TkH+YVWNwTN3V^<~K21cx7B+K)E8BkD8zJzy&%={$0qTc>7?JiG$ z2x1812(+!y4MS8l_5h6N2p}qe@C5lewgNgNYJwySRm{}thVFJ5Zgu_0Pd&IbLRWqm6uaQ0O7>H zy2heJ&vTJRgQ2mizTs1Q!)s;Q8lgD9qm~meD$`*JzJzVaU_g z%zMbQ&hT++a2@4{!LhlE{2hLB&c-`r4f(!K{TY}X6b(?wG_t63ObN4DsX6*{LQz6q z$fEI$uugBP%Gb*?O+%F1WIvDFyt4~z4yLG)d8z%TR?9xylmrM5_B> zR$xv}Q1LnM*jc`2SJ`uWrXFYgySPR9vY72Muil3c?iMP!cMKb&=yC5W&-Jl}PR!q@ zrLWw-C;!f#Qu<;V+q0bNovT#YElMqlRTWD0>R-29$~tj8<74m9Y+EBUp9g^y1;XjC zWCE$n!<93X>ppEPf2XgoVihNFb^3Pk4#S^ebbgJ$xSs6~H@xrGCH;ku8IP#4WW+Ro1|sG{cYks&F8!~ zbRSn~?p4c}1)JU5J&Jc&vorlu(CPJ7`O_B8&STrpG0Z#k)`zZ8#WPiBB}64&T9#SE zvY2uz&w7-dtJc~`YiWpwcT@4c6VnWmWfm+4oCie2hH`f<{Pb$)pihxeJH{au|AsoA zT4&qn?@cm>8CtF{bo%=YTwk8~8{&57bigak!=>I@vaf4gU$)%1_9&u*nW0Ibzj_<9 z&pldtvvNba%aK)e<>?Mu<}IqdpUxXz8&YrBTC=K4*)_HNh(XTqi9gR#>Kf(C*B2QN zv!AIi7jLxtC#d?hb-zN8q>D(T-2g}MCkkIP(|ntwfSFDjZ!+&Us=vynsHnui>2o;E^XI{1ISfB6*NL_6 z>l7O{rtb1iu;sscbEAfYz=1=aEZlUIpwwa7OvIrE8DKrcckggCpi~izMt7$R?uP%? z$~QG68JTnx!|nSLUB;zIZ^3!_C@+zo;8!Gb8i%H1#bk(h9fr3OxppOY5o1eBOXAJx z`YJ#QdP{MeK0*Wpgk3PUP(8!hRKECYZ?~m8&LuK{5R(XTRp9c)?*miDvz$ZaOOx`u zyvN0s*o}F4dCAW@l`#bUpbf%!09qgv2r!Q#sZM}bKvG&-9F9+%n3gujPHk^$KjXm( zK<@biUs%2?Ah5GIsknShsaR&yA&j|@@Ml9UhNNicPn;fC>O@{JZ4Q~ZIa|VCtE36! zWd<4n(!Puq5HbkB377@NJNiMPIq^jHv%^9-dFjZ&gP1=mMuiGk4`6-FhHDyMex99% z=C z8fS_n*H&*&iU$Al?^2MF7Mh*%u_9Xt)J?JA2}d)38+uh74k2zBPZCkg z>yzuUP65gp1yF({8V{9h1+Z~J$1<+Bh=$Q%z!uGuYbCMxyU7dh_@Y+>P$CZx)IA19 zw=sw}HL&eR^#oa6B(C9_ySSuO%#c`Dg|~>@y;?dt78sC7rVe>X;YlK*3P_ER)$?pV z7a$G^X@V?K|GZjsh+sHu^2be0x6_WGat$L9i8u;A55$wCpkuy-{Ls?W8Xycvw0J0E z@Nkm&B~VgqY-}C#DHxtZAT~|L=p39O-x9e( zb%_UF9FoAhREn(MQ5dd~@MsAK+4N^mw#&swJb__*I4Ik0r=m|1I81T<@J~-l z3g@=`FVcxoVVVLUmebNzFh0azO$4Lx*AnjxNmwSM7^pl5sMj^O+adW6tgJ`rZiM?F z;k+92;6$w;6xZk3iH-SzZsMJ-2Q7`*$P<8P6XfC-JT%pqO3)g{3kFFghz7z`LYN|2 z3$jq;7k?m$0VK2s*7rKh(vDyD7_cGhh|8SvbHa-wr&Prd--#m!yM(5npCTS3_|4bM zG2%G2CIh|3W<=We@C9hR`42}Tq&Q-#;GcPIZ-s}C?52F!l8US6pr9t8VE(IL!bqB_ zfUkt^hAK7&sFd_9EMQv_;ruby^{GZ-p1j+rIdjI~?2JcE0+YYSP4;*A3_{P6b?U_> z3Ibcg81>}1b57Apo^NP|7OZH1xYG0IaR_d$+cpdUj*p%G`(LlC^SvGdKt?utk zPt&CMF=+iEG(D)4lK*@VjV=EScUnGg()qKg{oC|n%C^`_i$-5I_vfJY$We>ZRd!g< zsD5`tR%-IMJgf0M?Onl^Nrj@6_cr~UIr`~v9mTr~f%i4)ZEwek%Pq>$C2gO{8*bza z`sneVOZmHrjA6Ea&JS6x%a!ykALKq?%%uJnrrc#$%%BXXMxbq+6nm1|;iWoT&Qo=J z(fPI2@Nqf^ukWkfHjEo=rrLfRA5qS0v|E-`?JpBl^?Mz(fs0qA?k8RTYLE#;)fvCX!hRwU#$yj$OtEOdkZsVT_&B7qoqA1&9?~7(n3C{a;~`CykbW`L;bNY zklH>_r)oE)rDB}_Yb{lssz2Fc^)rUi+a)iDe*Jg$QTkfd;~O!QM>c&3JwweB{3=F5 zc|XO=#@sfs44Q%qa-Yws)o+Y$v$?>1^4>r^mj&g&2KPM=!eo~s=DZBuFXA!aekN{T zyK2WD@axxs0j+pZ$#^xUcI*+TC+_8MQA%sS?uFAJf8-3td#Y>fDA; z_VCh`hV%Y#oqgj)b$!0_(W$SLVKZttvI>tL-o#$Iq$3+a?VXmfB(+0v6Lat5s??H; z{nfEow$)glv+J*p*AMeqEzml8bQm&3!X_YtAPmDrGxBBSL6k)lK+{j+TZjvfgDp2N zZ-~U(V`Pwlk^-d@A-H%@3xK5%{+ew4bR93RS9x#qo(SW{Zu@&_?Y*P6fr%k=g6Z?W zsdN1KSKe5>UyS6UgGGX@wSXUvg7Sf$J`2SkCHqc-r+`3#oLblhyAI?Q?wf5q%ov-V zPHcN*I4hD0ddEd%lbF$XSi)SZ@9FC4Swf0L`X0y@jJxl@S|w05*>*mV=TMZWxDu>P z!SmUQ_mSpC$w(NHIUdndQ@ZoNd@!~P%H!$=6Xr0rV@9o$0N2=v>5y#8txx@|g1&=i zgf!(hdOt+%TLd{QR}Q|Xbh+#GONfl0n5Hyv*O9#6b~oIdps{RIH0h| zb1Wtrl!e}{Y6V)+h}m#yH9AR>C2&jHT$_SBga}$}?p&og)9xsgg>TrMI z5~z;&auD%NSVcsFoR4lsGs9KrNMd<~HmgOVg$((?fNRzn8FfNui|NXGr+<(1z93ZG zohDf^#1{%H4n>loGNM*Wit_VZOclk!1ib+?MsiF*S8qZ9Vcmv&gj|CXhLR=X$rFt0 z9p-nWeolP8b-&f-S?f(%&|SCpo^-}kF6Q!653>i%+TrnW3xFm_9MUm1z9>4m))?N$ zL5Tr;Zdqfm%G%Jtg=rT!6mVf8J_b`tfGWIr-4%zDvS{5TjVLIXTea0xPP{{~3+U|q z+S=O43|3X81gb=;k;#n-xlGkWDh1}iL2yurCYdZKN=LpC-Ybx?h|5N(6Oq+OZ=Z9y z!FQomao@Rpc!_rhl<+s2f1A>X+F64HY6r>{G!{zW%H|Xf$@T6cNZ91tH9ufKTKKG` zJKPi8h?cKIB&^3@M1|((GKh4NV@>ci2%_=FjC=n@t;ByKWM}8!c6n|&j3li|&A<{6 zI!3E5>fkM;H3q4e0FHw24!{(*bMq|TNA5atz;YyosOXDg_oexP4huPHA^dI_0V%8-TB&c`zijP2f?Ym2tWfED9BoP!(Vr4HQVDygwc*#s^v!qT^ zq&T%)F6zm5zvd)4G~M6!iX3;1D6anenb?kq?GU0lqN983w=(X8MPZ7>s6$a_7gDBx zR|!LR9JB_+LMV9z5q%=~@D19`+&jqnpKECCZ7|f&ec8s0S=6yNll^>xIN8I6R4%Bm zaGc+vxH-9!CW-4IMZTZ7^Zkw{Rth~XAxoxZUiY5%ffFA@i-ryj*M(|-3K(GZm3i|{ z(zj$C?|M|v&MR*}8Aa73bg#r#@b->(9&FV=Wkap}i?mR1C0>+v)OG(mN%p8qje6e) zsHet*Eq2VvErwTa{-yTT=;*nPxkicKq?A<~R%3(c%qzJ{jslSw5ROFd54R8OY+;H&a zy3mQHBbOO-oIeQl@7kZrd+dR}Wmz!y>c5>57X8&sP7L-^Mc?cK+XON$M}`KZ@eXhf zRwVD?siiFN*Z*UFv)a1#kFDL3A`6w~ZswC=Xc)1)9p09)CP?9^rlP-fX0u0sZA@&> z9xpx34Yq=P7q2Tts_-juR<{LMUtzrM^e4_~kIS$Hu$|E^Rb|&9Cms^w@n4Qx3 zxOwT+%8d=8@09LlsN$_cJMiiC;u#+#=}!J~%{cezSkB+Ua#0$@B+lE0Xg9cjRm@x; zShvOH>4HR&ghBIN=fJ{RGxPK)(lV-a}wSM1k02JXbrQWsu&s+O(k)ihg%yrrUi zs6=LX=NRhchGyXCFot z5=^)`;UQMV; zOcNYO(}@Waa|RVNfE-fE3drFj%^%XrIzI(0<=eZlzwR5k(sTsFRiTZUSb$J!s*_g> zND1yykcsgqVuO+wCw9e#}=)U3?EgYSRMVv;&`;zT80Y4mJp_8cUW zDp=ou3_%~gRZ?(kmI7;+z=YL&6UY=mzn_& zTX^8MU=>Y+T-)62Ji8R{2CH5Ab9VcOd8O|D)a{=2c8DK;K5ynd<$|i@JZ?ETxl*sc z!YeB)^z`(Ny}d`IQ-?6wpDy#-@}fKP!vKhHgykJReE84YTuMH(|1KlY;(xq0fy0D~ zs0aW68Ya$e8iF_txp`X?>;Bdm-V0?r0+II*d>>d1UhH#A$1At^Ck$i$Ge8x+O}Dgo zJ(uQ0hQHh;$$VBaKOhN4MU~B$U0}wtG{=XW4?>IOx%^M;#pS*>#?~qi{0>NG6hQF# zj~dpqZ#$iTZ3sBKf6Uz3IZliBJqG(fzGen&;NXFERSwTOt*!dvM+F51ffzfX^rOse z44uwz#)G|i@Hpdv!E{SPpzhq+)!USczhdDx9~=Y1&G940_a_d5$%2M(GkE^}i5Dy^ z0K8cjs&@MdEw+Zr)5J#ZL8IPjx<7>O{P6@LHme!&;HG+y4`+c`O{ zEG@IF9g!XUc-!mAsVVTJ@t>_17XR*L+@;VDusiti)87jUImUSwwzi=)HQK(B%dvmY zo)#r&Fv`)sW5TO&kPPd%OGX!b1n8(KktSDPPyh7X+n*3y0Bp;xHMpb450x7%=3o2z zqU(nCr0dTv%-T%$z5RmS%+>p|r8piZfb9t>v)z|o%np2@MVMx$W;@%)_5A$&vQFQr z35B1N6QO#)qNRoHz}YwbwYT?LF>en%AzX#1Z>${>qpm41@rg=HTaS2(9d4l_5uLDw z;eEr;sO=JF+vv0K`^Lb)c^VoTKwX%>6?P5dP{e`0xVT8qc@p0M*_mX6<<+*Zw7f2J z`P*{L&8_$s5$Q}DnlW_b+@gR?{_}|<=B-sW#g%WJc{QAeMo*uZ;EM5#Jnoj5cmI~U z)g#;Iv|rucFLpIU*OlcP2Um_MdwB z;eqSs15M%QzL?%+@#^JJIPt*zi7Jg$zQx5NJ7%T7e{QuaMsqsoGPFg@Ifn)QsIDrp zA1URoluH$yE6eOCq2DL;tGjX`+dAOZ)|Y!S%BQ#AylchF;UCpoX;YpyK*vBUp4cl@#zM`RgV)tTr{ls!xQWq@6MWOQ`Fj&}O}>b)-G zN);;G=Lgi*ofo~!f#gw)D<^9b_4r|_6u+HWyz8d&pFaG`>rd!*Z+rFh){$kU+d4+i z?OLLqJgKxOIM{ji=+Q3+UWy2(#Up?9n4qO6*X+!orgB8??*$*1s8ylXf{IHoz;+yA z)Zin5CzNI_w>}L%?!FV+^GI*T$~Eca%itmBm*QvKTwOs2)Q)TII4xe%+QM zktY^jKR7WdChGlCC5{&_SK&ShF)@R=WHax9=sla))E(mtaXuqIB-lmw^sYz69ugcZ zj<>riPcIr=yu^O!-163&b4mmK+cxWFzS2PYyP~50{bNZ*y*hD~|NeQl{!7k#sv$1@ zN;-65a&`eHK=6uKmkT0a?jNg`{8d@s5ZPvuIZVrg)BXOjM|v7X&wIdi&{6Bnn1{CnB{{Q2_*9YP-d6oOT%MKoUq=gig3 zQWtbGg&yro(R^%Fl97MiCr^;0xZLM+M)KgZW8G$Y)n>O$9#iqoy-4o#64Q93O!NNQ z+;>jrxE-Ham%Lf*p7^oFuCfYkd9Zkflf!m&`?Y{Oy2q!NZ)P;-1sZhaotsHFR?Z5N z^<9)i&sfFOyo9*;`;HFw=H_M+=K-0_!0ZBgis%^`2VLb--4A<}@A;{b)1{VE*RGc$ z<#Fn2{V7vvTfOoduau-z_l~__J|uPg#irIMvxVdb2HuVucYg*QDlWVvvTxtMt(2{o z-?=}V7n#1{5#VuIrSX?s;g8hak}f;9+MkMGSpP$4ed?!Rh0GHV0wuCIe_Rh=Hd0wV z`$l=1O0Mv#L$TDICuQ7Qe{F(iKIQ2JZu|V6a^KX+$xFC;O`%kRroHj_fl!|=)OlKB z_8Q2jjtDb;5n`P9G+FZEKC}ORQIpdbe?Pr<`%^}%TX*e+`BSz-ep14xBowlwm%`-~*a9uySX>iTN=s^)?J2kIPkr~m)} literal 0 HcmV?d00001 diff --git a/graphcast/docs/project.png b/graphcast/docs/project.png new file mode 100644 index 0000000000000000000000000000000000000000..99ca505101cb697e3914b416489297090a41538a GIT binary patch literal 156767 zcmZs@cRZH=`#*e=ot20r^U4lMR(5vwUJ)THD|=^VZ?Z#n$POV%cJ^M$&PrDHef0VM zKHuMcKOXn>Pj8pYd7bBL9LICL-k~Z=G6Z;(cnE?J$jM5oAqbWbf?WB6gAVUJ(qD;# zf3Ccgkki1y!I@rA{sBM6d1&Z3sTsS{+B@2rTiTe>I=S1M(Tc06+(QtY6wRo|cx_zx zl1Cy^QCufFwNdK+*U|HE>o0R0w>Y>1DDJ%u8JZK>5*b@==-+ayt3F#w9T`2wd~5O= zjZQ^5L3dJiPw#ETz~(LLZUaGi_K7 zb5B01$2!?%yrlGC%5Xe5MBkOB_Uz{rpXdzh<@wKsjbeh!lb&Y7X*PWlvR>LaM3xvO zwnT+>|FfjR_u=7%nUF^WG69{muj0NJXx#a*y1zbLpzNBpw`D|28;At73Nb`P`8BqM z%;Mm$O#W~aT@!Ygi2VGNivk}jp!F$N2$uH*0sNT7C5Rq%aNC7}>&^2}PuCB&LGa7} ze8$U+j;Ilh0)OTDhe$Fk_0Rp6S#7px|M}y8zZc+HOXIVncDV(gMv#54zs)Njru|xl z)W3dNa5#(m^51W2E*{8MS^M{^rwHP=`nSy%9c#~~^`(%Hobi8e#z^;rC++`U`9C+X zl5z#xq~mF(s>G`D{MqT{#s~dZtzKp6=f1wSA9Hq5^|Nfw#N~ijNZ}fkDC-m}%*d|5qeS$!)Oj#`SR!MGLsof80 zFrN1?vNf0^qyw~a8N1pp?diizcy3(9cvPf8DdslO_kX{hs5OTSYfqudcK)}U_nkXZ zD{&^jZt5&N<5GT&V=NS_qh~m}UVTub6_!I-aZiZz{rkzBK^@iDK?5}f;;(eP|92Qj z{ool^Cfm0socNzK$uwx-wHw!A2Dn?1rbOhe#C4_AKS*oJZayT|uQb<>{m3?1x#aG2V6F-H=K7IQ{GpVnt4CFPZEAx z>lXdsTK@3t*vxCP|BhU*tr(T~PnG{XPsA5|e;@w;yqPKR!Qy|H_y7E<`^smSt^d#BftC0FbFq^Tw|)oF#ba*c zkV?gTmq7gd=|;Gj^Kdi%T#flbAi<1&?OYM@M}5!<{okA7qUe227)YB+Y9bfAEQ$E( zDRJW$X{(0JmT_X~R8M8OVfEVKedlt*kW|IMy@8G-aDRDAOMz>_o;St*E*-{g8y7pM z{&UX#6*NR+gKa8OF=Uqhe}@h3xp`?qc+ygzm2@@Wo=AKUf|SOob$8~=yK84ieD^|& zcuEKZw$+xZ{7H}@FwP|Da^%nyi}5}&^5LfcqRY?Mi9&jNn|OKjq~p2YV@q!$h}3Xa zkVT?W0eNy#ZUsMSyPHt(R?Yglo1{;<+{}EzO^p!+nSjX zAa94Qd*mMzRouyXaP{g{p-cO}wKX-@FnqX8{&xa;pUcuSD6to8-$m8Y8(bXfK#k}B zG&$J*b@#^duK6h4gWut3H%bpp+;l&3@3B;bWrwN>FGP!~yhjjrY~gEvF9LDs)=Iry zoJ;x>Hu$hs$#00(Snjo6mK3w<>8D!1A6?73nUs;5IyvFw;!;af2G8-lY@P}O5prEw zO4g~?Gcfq%vPIe9aID|Fm$5@-5)bo3eyx1GvATMH=4Jnumfzw>sA!pj&hi||zL^AV z|7V3@h30*&!+sPVrEno#CH^z|w?7UVvi^Pz@3k*QgqT>mg%-8i{T9TFI)|*M^4NNX zDbScARPs-?)pR5~QHRaX19xgfEKck$L$ShAs9>eSB^PA^8%DD5m%v8~rf&_Uw8*`e z-)m!t>(ou}L@OVkqlMAmWP#o#PC^ojO6zmflksa`AL_KA?Mm7jQg;PNz^X%-)0NFBXS_OSBE_SN(BVTj@*dGOa8qDA>Pd z=AlZ;$H#|_jXi30aA@S7sa!D3qj=T-$AO3F#oxvI&%aynq&95?62qP-Dk|#h>$BmV zo}6?)`@1V+MYi&ifdb#K#iyybxVW+L^yM2r;i|y41?%y$#N=ceVGj+~L~56pgxGb{ z-aB)`^#&J1k2!C<=))pd(Ovy+K*gdJ5y)k=RyF&J@_w!stbnBJb_!ot$L_cCbK?B& zZh8}l@#Fw6?5t7_?-SC7`w9`e!(=wpp>ahdx@-UB=ij8hge^8uY6U^L&9@VbWg*UHmpV?DAvF!yu+%CzEAsc?1HPB?OVz!D!M;2^U zy;7xuVcWWC8(y}=tOdR>dw#0$cE?K;js!xXNM9ABB?*en0bFk8k zoVZhZS5{VrhSd0}zMFF6x2-t|#&?ZOO%;z?5j8p^2r)huL#!-g{A9~oL1@R~vuDrb z^BiUFvt{+W?#wmS*4Dx}0TLj>$5&8P%;0x0v2$jijwT`?s5jt<@6s?dEE`^@uF-tc zPY|_ez>zGS`-+&_MW68c2oZP6rlq3AHGf)D?wg^K^s%y;-mB*djElQ&BE#BuVL^Ba zOw4o}(nMcDA}CKR0xGKvZAXMg3@LDrYo1MIVl9sjPE_Ab?8RZYD)%_H`s;6gnh+9} z!CNb8H&scq)s!qX6I36m_XqP|lf!9gAbH1ZU#mW=diuud^D3QQ-mv99%R$jdz6PtP z^&o33+GE4l=w&4oKl5A~bQ zjgHG<@_8PMk$f)i$99fd&JU*vLvlZU{F`-ooJCV}vA({3H0Np8k!9$8FnV$E@8|i; zi=* zY_#xD_T94kq)Lpny-O?=Yo;5P(_FcNWZkFg=ZFvp+CDe(H+4%j8YR(I=rAhiYreyX|J3N%3GGCM?=;& zEv)TcO!f)A`}BUqlw00{aynY*Au-0T6Q>M6mEn^ocjPMb^4hIxZ0kOJnBIbQtcW6s zHeN@&aZ&iB5-^iHV7?Ujq#|M4f+)#Op8J#>-(=QdCTY zS8}plOG^yDp_W##5cVAsIVq|3?(Xi9k&)xQRo9a15)_fr?wRFfWgM)mU&6>%N)!Y$ zT==ONprt&pI2upg8=Rab@~KyCZCQid|&&z15a zTExZf@cdz{N_q%!=gscITm}car5@XpdJDC`(pgHmf4O@rMTiEW9K-AMCI5i>EAr2Q zkA7f6dmWX{u=E(>@XwJo9r}|0;r9_bY$?Kbw`uR;4&X}< zOLbZ!^?Uv@fqsj|!B;Y0;~>nVvGQ0ds8(E}717?gRLVfVRm|w?@6`R0VwMvzq|j2a zJUK3|hkhLzg3IPM9qbof2+NaT&^??GG$6*vTxE#0V9h}HV}6jVLMk=vbiobVL!1~v zNTt-L6{xFp2@%;i)7Coq%vJZyBey2Cq7e(O7}7iEc#_KlRyCueqe=xY8yNzW3Q9^# zF|)qbbziOn%ZYa7cS_lwZ9MvNm6P<%vv_YcklY1kNY+)mcxHt zo|IoUEGn<@c8cyF-hHjZeghxth~Ts9;pXJokB6|wUw9{5;1vfE+s;>> zuaQcfugQoCK0PH;14vL>S}KdGJgT=@-JLPS@pn6H(i2gS>oP@=SXx@%OHmMzdXv8n z)uFvzDyyG7Qu<|1B6FBf|4XWly-x~^xpkH&vcE|%UM=#6N=m>O#H*0D+wg{GItuky zv6rY7sSprLWUf|CviI+)6;(g&pjwN4%_`o0$X^3b&+Kb#bUF0C(r*m6XQy(j0d5oX^rU0&DiUIQY&>_uRAsoDaH7EY zD9+TUZ!PB$j_Q4tMKpv+Ld-GUfJ25HT{9~6Zlb*FKINlE25Whzeg7Ld=#w^Dq*5sg z!C^02ZZx*~qD!eZ*1PB2m$)Lwk*rQCr8;uhar+l@lG&gNCq+yl(!ZCYVBUe=_-gx_ zD(TJOt|wQWnjBrP7aNdDg-ZH&(`MOfh*Dw+v?TjXr-dRrLS8Z!)IlgVNOF<+{el!M;rE;?1zq0!hLMbQhq|hK9v7G!#f4d zd!L%K)9K{sW{VhKoxi_3ikzfQ1bt9n1xILx;x9w?HiLo@ z_8)owlLQtS9!xFopM{FQrKvd{G!_@;b31;alLV#Y_tC0r*7EC|r}_)#j33)EJE5p0 z%!z4g$mu;S{FWsA-P+dY8bX`gqpzPzG(d2AJ@VTM`HkcxZNQsT^sxhX?zCrYz0h>O zHj}J9mY@IJBK4PropoV3Tj5fh=o6)a`-x;p@#jp`=`?|{vZf{`R_Rfh=@2ak- zsHm-VS?Y{E%aX_*`1)1S*}10w=m&=^iX<&94e?u930Y)Hl)uZ#x#85%-7OO}NkbTN zn@~nmGtsb|TxxLRAXtJfy#42{8*CHUlZuK22Kk`uh5@ z*W|ekaAgp?3naTqHg z8vtYpRpw2v9Hks7%M0!ZRM)ePs9m9~$AwT#`pnbby!n0XMJmNY8U_d$pk8`&Tjl1r zZ{J2nNW&tfR5f&UbzNOurKB($IK4#XPj_1E{&masp&^p~z19A#^78T!;;F5}&`{jy z4u&v##o-&_dC)eht9g8H(%=OFrrvb%4>^LefM1gEe=3T?v(eY$)@Fd zdwXz~Kn3|eQv_G&u*3W;E-vbJ14FgU0hB(eMLmfc%v>D)n> z;F_8T)$|W2@S&JGJ3Cugvf_N%+}vDSv#_%(pWcGLD44N*cDy(JL^z^y*0)3wAPca#ts(kU!L_enzfMi*ai^G4 zuyOv=L7&KBNH-eHCva_W%}pOau4Wi?c_%3si}^Qvrvg?v{CMNhDmpW%=X2!;nMy-Thx6>uEfk{|Uye(((Xle!!TVJnRq_IMH zJGz63iAh^a>*p?i8y6WF8LZ;t7sMC~rFC<#>iM02X>2cif=T7#is>_-_uA_-FajXk z|NVPTX9RxR$oM!66$O4Uw6NV>2Ogr7tSqA^PaaT4&Yz6c4QK-5%gD?m#6XvIErBJ! zxyjDP_OSEx9&Gaeyn-a?#EB__fq$<5`o&EdxrULPlq4-9V{K&>78b_H#1zm<+rI}( zW}@O5ET{U|bXkz)K`+ z)^>c|$#<9BSDaXqJ0%?!entj>)U{@5K1`RG4$xCE(}IYRIE0a)G->B@rZ$P&nuC{k zng8P2ds`_SVR1raK$oeU?U^hT=+1XbJss6AH!M{HTPd=PvkjyX1mjxpq69z3yCHuA z=_0ass*mLE5KUvD;dBJKsmYp4;$%l;E9mK!c3W@n1UJKP^n`owl%#GxaBA&I_oJmEmvEp?O*E~#TTeK9SzPNDx$6^?Lc-?XU#>>SR%GU{H#7)lu5ubR z8_??Y_4a->shFFab31F5y&BSH*L0B7zj}Jy@0*$SVXnpZa{gj#p2Ohd@7W!nqgY8c zoCR31oIbnZzGof2Le9TPI~-_y&vqkk@kLPyaNoOkExet>@RN?Z`r_WP&tBTRg0k{; zj91MFROXhGUnQ_$waW}(X8^^Xv^hLKnopL`6Fpy7c2!8Iz3opsVpU_n5iOld{6!Y^ zC*0@^-CgJ|J3B?>H{X*2Un^T%mwE%Z!?8bqDm7T8Ty21fKIUu!T2F|;{uI636HSy~ zCi4CJXvq^)D2^u$JUl!J@wYrYJzEY+?dTlPGX6~4wY;dB9I*lgSCg6a{QMjkZNv1| z#g?6K_>EUCjpy^1=kxtpqQ1L7+XJla0fd{FpbK1X6G*B4<5_}%KBEB7u>#OS6(b`f z>sucYOFmz13;Dzx2DKLmCtlD`AiuJc>F|sli{u;zv~p^s?e+Ebo=4l3HB&1q=E26W z523{KPuoE-F0jr%FRBO?i|Pc1Ah#TN>1kdx#3yTP?wh05K8sG~c2y1Vy1&ZAlcsiGPh z8vKnO%gbA&!q7b&v*e{_w1~!-I#IV(z$%LGxqfc7I*s`W}u;#>erEk zN-8T8aj*R3=SWeAij2%Q^Kx+5jbeOSq`{h@Btw5KQhIo7EVu!FJJ;;(X=aw!+&pK> zUBTX1S?PFwa-d$I1ic_PZwD3zNhro2{fKrGrGmY&Qhmg4Y)s1^TpFk%(16yR8gg^5 zbkRUVOXf7%-r2E(M>#tf^^fbb*E8L+eg-X>Mb@yo!B#W+vjhnVH!#=#9_6r&s8e78=9z zVTf1P)>`qTiqU?kuXl%$NjCDiDNJ#MnEg^-FZ{t|8zMN%kXjK$I52VJl2V>AUM5=PO;ja&ngU}%e!fL&tzIcsz$;pu zBs2ICG{x?o9@qwMnXdd)Fp4Yg%_Sw)^UdIYz3KN?NyjFVsHBojyk~{7E}K0FdP0lD z$cd5q$x0AQzIiF3(fL~Yq$VA4VhzioLjrMH!{TqJhdn%}>IvmhB(+?R(vva@N5AQU zi*Ddh;krVSNlc@RODrMqXf%o2`$h3oC3bws{49RKkbg}=w%mtJQx=j>$8kmZ8VzJC z(l_F+9xAZkf#xL#!r_uzd}WRF1=me_jE&~aocqs?KCoyd3`m-V?&sxBvF^QU=6E_- zOl;(oh0ad2U4HfObD76LkV~=7RWa5S=w2CA0aVzsVP)1aI;HRCIypJb_%^vd+`*{j z9v4X%Wu`gz59(1LKS9 z?ZpkRZMfnQKQZ*E@~g2U7nFBfYW8S<|rVJM7vsBd3@U;nrY%I)?l@=ejlxRJg{YK z+Aumht7@eGI1VLFQc_a88(P#=lv3b~*U$an!GrRqL%oXx%Da&yp@GBeOYY5YKh4?# zpn73v=R0yGG)D)hP?zcUd`nA3#jE^J;D0aFSdX*c2p_eABET_a!JF38(*tmhk&*Fx zxtFH6L}em8U;+-F9s`&1hTrdsx@8^^71k?VxN{v(ThbvNJ0u3_dKK@lwuBD+d zz|*(qB`4Q4TA~fp;@BNUi(_7O3|+I6(c#65Eoc%jVNi;v9b#x7Dp}^#+$*Gd!eTU}<67RD)aS!Z*FcT=| zo;yu-)z$xegHfw3m&W-WJ1Z+eSQHZ3Kiqx1y}h44ed_Lh6SIpjNJ4h@?>D|?;|d*% z)D?aZ;C~I}J-&bc{+&B_7~?IY9n?Vi0l@I$#S8p4+USC=t}dX1&CSg~v{3eGeJfzT z%{4fel$Ay0uLD<3xHy1WSsSwky?&?Z2>4}5aWOmqZ6Ne4s~SoIB1DXzikh0*(8!1g z@(5Cvd{l(c5t&(8B^Zd27+DmM{wsaBiDU*sAR7#gA;DWWaHCSL9Yk;grN4jw zLMLg?(Ttj!ngWFvh6+bQQL%5iCmCkb%G&xRUpR9Z5H}{}Oi3jx9z`NTLg*BuyB?10vwAk6qKrLmgr>F^F;YOfdfyu44SN*m0s$aQ^SU6P0}t#J(09Yvuy1Qa}j$icJ7zdwA*hYdT!>F zOJO1(PCxb>V=_odVZq}%WDEMk2sUkOf2(9vC&P|;R6+qTW7P(@lQJAKZ?D&lHg z<@#F`3LhXy5wm@+>wUiRFB)-;zRGIjk`(w4_r-XxA_GYG%=KLbpAdhUo6C$k ztW6J%l+I;gVQKL3{=9fVFiXk1=KRp-*L*-|7t1>xVb>Ln$B(Zc$n~A%HVmolc_xc^ zeM$1A6Z?jPc?+ODC_S{q6NVybeP!?8gGVA{#UX@~4W)8>JM+z(k6y=r{!x}{Z%|cU zzTEo~bk)2{U{#$qRu;g!w6yeqw(GyWZfIZtYF^{)Eb*71prE=r4>sRt0GDjr^$~Gm z--}h>i?hA{T)DfIg0ZqFa!x~@v_8ZyI2dc`7RZi+LrSGuk5MGy?UnkphqnQe*|nV4 zDll4u&=NTZy28I+B_J9VEplREY`pNZJ$xRjSqzVgzWyzIEI^ji!^82X!RZW?keZ);F4w!(xOMLypNZY9m?1+6o;ALj(+~^*Ri;J>@_?*Y*lm88Os3H zoq%xX;x2g9P%6{O6Qy&D%F5yWvi!D zlb76m%*~~>3Ut~RG+3UX3K#l3MqH=*qV4_yDXogEoFj-_bx3LW%UVknY9tC}2UdSzu& zX?gnCx8dk?lH`LsN1j5NFyGMPoSZ@>)igCnXJ$Tq{J7>=Pc2t%La_plX#K*gnPp%v zF*{6%Uk-E0Xv2pOKU^By7CsDGIMo|~h)`b6x!O@WybioaMq1jx^{KAzO>yEicE{FG ze5|&V`_JJU4xP`xn^#dsGi4}&bh%gWGi<^0o{=F9}wyD1U zruVbAmJQX_)%EqqU{(QKSQCh@g^UD`tDr1-K67zN2^-%B0<0^QAsJc!l-X5&!cZ0< z;{v6UrCP2oE+WMi%Lg7{Tf;p&9zc&D$=Y>GU;H?0Ge5xh#vno|NW$Ug*U*y_4}b^4 z8YXD+d4KFK5A6n<3*U8_a;NoiB(o9YEQM=<1VTVSP#B?BWMg6RVbZ2ZL!KBTH6OZ0 z<)jTXa5*{Lh5^dR{QP`#s8TysTFj&%_pLQ>;XmU|JMi$PCFXbkV_Md~@#|OFlg6Yz z^H>Wqe5~>DacDt+h6dXp%eY?;+sMYz@lO!BVdv=l=xAd3i~tp3YCf=1)n$u!DSa!@ zUXpu8Cni3+THk{_NY&dZb}5=F0W?SSgdOs)EU7=xkY&a$k-x|RFIasO^yi)y zyeMjqjF6L|d7g-($$T-z1HV~0m0@~zki*tWz;Vn?7G?I#-aS$ zl6M37k6Z6X#f5;baLaZAiKuvts8KXO-_jML!s_OGG2}24vI&h_KlviIV}b! z_27P85rSk>T((?vtVTOtjdF6%q)V(xaIY~c&{ns(F_#`eq|~APBt=0>NH(mk|5?*D zydX%%43m^XlYY*s7mIo)OZyr@J}lxe;7Twsg^dJdr1gH*$b3jGD%!`%iH_g|sd=nW zb++HiwNxvKKgduzuKTJYrt{>+rnQFiUD?T36ju<${g$QR6;yqw2&Pkq94$`T!c;Z#P^$_tc+yMinrlBh$Zw9lpL4^s7+hu05VJh<$L=>1S8b)^OwEFpWtGRzce|LZFUSLcw$r?V49nv@ zwcNqiU!+M%rCT$pH7n4O*E!5M5+F&@T>K7+jg9?hy>})o0t{qL`~}5~_{*93+Ybz2 z=`ZdkBqb>+DLGP&PfpgSO-3cu@^A*uBObI>teiRI#s8AM4! zKd^`MgwIG9X`DRwJ$!x~$&Ql}{nINzx%rU|cBt0FldIEPhZ?M-Q&SDzqmi4z{865$ zK~~DhN_K1Bv^^pj;y|qupanpHL=7hn-^@?b$#EN9fP)5`Jx`6yG)rt-in#);vA(r6 zW^NaN)X*02+m`5Ec7-ATx3d_neLDpuBXTH&ordg^xC`u|7i|PW9eSCcU>eVZN@!{?C$vv8^qft4gXp=TK+TB=j zOf=up13nHm##7)|wc%`;2lEaH6c+_4#y=yeCBlHicReqJju0dla&mI$v$w+A-)og9 z6~xQuVH15$-pD#@=`+7CP7HdfYZa+!kJKbDibPKQ4KDzaL5qPGoqJx~DJYzr&-v@B z0eAlh(L_pPiZgRjK~eulV~-|OVI{FA0oMS65iF4n{UuU5;HTB?rH^a(;*<)`eSG11 zD3ToJoHucpI8DBo;$m^SD3aJN(|T9kvOia9>5l3^Z&e-L!v@#_Z(FP&2ndG6 z5;b}moTVffux{BC<>6XB%b4YvrKNYna-jKKyu@#x<$OK$FBF<}D~X=v(e-uqncrsP zvnhvq?Y6CH+`7Aebnp=|^+5UN%l&*LKKQ~M)L8n8h7`F<|DG)8<0|crLiLQWiyH)$ zzx5m{viX%`B^8Sl3UcT$)NoZbdm}gtp2@Y|{818i$NTQD+yhpBKki^s3}%U(Vo5=6 z8B_Or>1QL2x%KjM^Enc7xVlO+`_G5pGsml*Z?WP^0S>^AV<3=k3?Y+@IsQ5jlJmNj z)jc)}3mI_yViPmK)4In+yl}Vu)Dy>t`Y~asWR41Rrq)1wzo0e06bXU3W9waXB%il^ zI5NoSOoRLPRP61Otu4>9UO>_@j}bqWvn8^g{fsHYH(X?@ckY-lBX2wJ_QobI`w^hO z!SrJ^9ThWF{}@n~_4|yb6PvCDGsq?k^ZFrn?yBh(O?x6@j6$9mDXSs!SG?iE0$mDZLDVn zwmmh?+W&sl;O!Cq)yvRUY0+)PMY(@++XP>q-OLDc5bxMwN#-rT!%qx6+h30LDm?pM zP}HCO#qwAF!(+*l3gM2&$s8bsu0r7G1k_;wUjP>XsCIYDU&QI>4pIqw>`&ELGpqgG z-{%SZJyLYr_6CNODmbW~%Zb{5_$cv%uawc_`2YRx+T6Bx23#Dl2_(?cQd0p3z>CKD zwtxQAkL&ki7-XU#ze<0Ubh`(*tSl-=gZ1i-ND$kl*rzz|gFS|Z*v>XQ5IDAfjfIZp zPy5NxPEIa#en%#g@tDaCY)cuzSa0!G6~G+HJ4~pXdq$Bc)mHx;G93e_T-l!rK|V%`UjRO%i+>{>kk8Btc|5yGBI5TqyO^qivL z%a)UkGSk!dPtTdHC@K#?$B2{9>m;~|e3{AvIU289Em^6+_3Di%s;*Hh=Tgt8B@}bASotEj_tJ@(@eo3cFO$DWLL=z7bqCjbJiZvv z->uZWmq{XtBPaGci&xzk^=|1pjdW5_Xas^V^9{ubb$qkS9K-dOB?$~S-WxWqIz} z(Di$t_xxxB{Y1|%Xxz%N?WqvIrxXqOQMOjO*1uxfenbHRIQF`561?+U*-2L*s{aE! zB9=Qj*({x{5r@GNX2PB5@h)}nEFy-?=aKSMD@3ubxD;yQBpHrApYe^h78FTKz;6X}Ip4T_2jARUF<2 zj1l9F9i}lfKOeoQvkWPXN;bgD^@*P(vMp?FIapYZA)X3yWB^GF-FaS9lkjR|0z>jc zXq=v)<-nc>-N~^&J%_n&`g~A^W@{EJVgB4wmuFRs+{2T%>nYdT(J9bF7cYifWCA58 z64wWj+q}ZtanJtc#SC86cRcfZ-wPsH=yj&r%hVoG9X&lOE{ZP;T!J#-`czan?5_<* zMG?lK1`lTEWcO)jwYMj4Ep3Kw8o0d2&VcaOahaNcli=-N(D&!2r>7?-DDZ+*hQMA? zEf_woDffP{)9=%%u1p-*L{rRscILuzoj?>dKzrHUbmz)DKZw@kx}115g8~@R=5VsV zew*e18|`3@TJUPvNlTM4^FwSaA;V2TT)|pEs{64nsl$_9rx1QoQ=vZ7^6s zvx{d^wsq~iOG_CE&hHS9H;|>cI1;CBNUU7Hem!~wE7%yaJBL@#b^SEPx7aWfzP{XH z!j)2GHaU0cI^C_LdjNQ9*3Cw28n@J)_Pv~&%ZLQ|PKpbQO z&w-C|@oE%4=V!HC$<~y$804`QCq2|_C^@5Wa0(#99MT@*n7f= zC2%Ht7oC!s!zMSs$4lX~Z>l$7^s@tjg_eMMUn!yVeM?JAS(&Z7yTH6+QC^-Ik$%Ay zHGu8Y(~W@WxVgDSMau^*Tny;Z5jZFC>(?&>j`1Hqo`Pb(!q;-Fdus>GY|t_?GpR{K zE(*Z<1{EBl9fZ{TG;)=IWATFMA)GY4?$yGy^gK|D6mz_=x|&pA?L`X}^HQx= z$$>``=AP#}cAHDJ^FW~42x2?moFShFA~;^|De1-~3B-LN>ftRiA9R%*&q+FAiYN~# z7Q}L+J94#3T!)`S8k&}vnUiy8bF;VW#%Zc0F`V#XwM-?%3OJso3xsf)gbs8~w-u7L z*(bztG!mdhz>0|)!bSWR`7>8lP5J_8fCB+`Y?^EFw?@Km0B$n9T^mqmfBT|<@f zUHDg>AnE;Em>>FE3=PZqGL#suUB@hh(AZq5-Stoipa84zXaZRR3<==zpqlsC)-}#< zH?`B6N_RULphzYcH!$%^gA9Hg}Vmc8Xs z%;h=*(ItAFO`Mq7W!jcJ)tCY$iG}{k>#@E#F%(Kj-?ToRP|bksWrdCc5I`reVKr4% z2PY<8qgTM82+}j3XdlRX@49`05L#XyW?~@mb{fCqL(0f+>rS@g??tbI!o+dwS#xz5 z{aXlJ?;IL3$f=c#=A@a-3tjnl%+EbXQ>}b|ontn%%u~nib)yjF?Ud+SJR~K$luBV# zOoCpMDiP!&70(<`s`jFKLvJj%KC9pq&<){ZNRGMfj8_^*kMlMmK9Y>PN}StHxhj2) zz`;gqq>}{>Yl&y9tJod250@0OGbxq5%YR{sD|L2&sU91Z&-|f2@%*wZ$VG7T5sIf( zRt^yh_7*&pYuuGo|IEpjCzw;QB;Kze80fw#GX8Xm&VbUF`0uB@^dVLpG2(Z};>>}) zouyliuLoZf{7_BiWxrOQTizzIvf*?B%vl`EYV1uTs)&=c`dZy~LlVopD$J zXQQI(7eBJr$`~Z&ShCfwi1`VNawR1v-^3RbCk9lS(YW-WKUN_xCgqz8i0^uN=t#uu z{G_^_^+qfDx2GdlK4{}KK4nI|CMws6?F!79mwHxlyRS$R=a}jKb&duO^l#!hUoS^uOWm$X{S|~W!VosFjiEs8gcC=*( z{QWKRmQ~G^1&^(f4)8!=zg@9?bOKEYj#Bu$f zC(an+u-NLq#tQX!m{iOZ`VoKc?+-W!yC|+_xvHw_xHf*>>7?ms4p6y=4b9E!Ar7}d z$W}wS7+(LwAh;S}aI9kL7eG@|ScubJ1{txQ7o-?yvpCYZyK)h45Jaz3t5gdv4LjR^ z%tMHJ{|KTVbMC7dVB72~7iqu=G#l;|h}4l-#E|gZzaKxJC)BU3q{K=Zmd=%}RkFCt zk0*_V_z}a7^d`S#&WZ0>6p!5`hHxK5VUjtX(Bu3If*Wr@6ysRA*oJq*u|DSTHV($W zl+jF+Cpm#sSPq>1@W5Cup4}4f;EL!n&9U_G@)BzCYg_2#)G{*S#6b_6N}2|%scnHv zF0RuYj(0KOk_4%BXZ&+tye==im+(B6yAykoIdiq>+B8XDu`xN{UV4?jVTyo#2a+F2 zsA(utw`ymytJ0wp5?H1lTHjZ`dCQ3LUDTVAaULgJ;Vdm1J99O93I!T@<%%c z6)M3!+u|$QLJhcNyK!gArq+Ox;I}1{LCgH*LI{n|?^bLJ&N!4Rzc!9B1k#_`MSmA} zv1jjD`K5p&d8mYSh$UPOwwU81atR6?jLf7}NfGlZ-GX6EbhNrtW5-{Rxu~dUbx3bF znS3@6ay~4C`udbCVu4uMkB4a)a7e57xuSdQzDZW?CkQpr#*=v)`?SVh{I=cyZsTRc zvL0>kkAsW;*3vVTjFUhnUyG)dfSGtfw3(*GOSG2kLW6|2G%ijK7xp@tE_D*c9mJ0v z`)2BUmc*5ds?@v(d1_6m@@Cl&dvVciiW44aabpqWX{tE}PjPw)#XQDMj6g>|2BgP4 zPkGRse$rOj`P+uOPb4FVj8MMv3O5VC*5hDx`&rKXs^3F_#Pa@Qj%wv6BhCQ`>i*Mq zwz+h2H?$ax!u#=~`PgPS!DuA7KllghTl*D<3HpA-6BVfZ2&%c$T8|4kLOCrBG{o-< z`cGC`#-~n!F}i+-IJAz-*FRqtV2>6nUvs>Vyv<8bb%iD-s<=yFJ`^)vi@Ivt`*LB) z@WyPQ0va=kz2C^%n^g{3z3En_4YW61_6&7_ZVE2j=-D#@5@806|4!Ms+OfwQ+Y!xs zUu;^IO2_v+OH>9k_EuJe)UyJm!^8yAGW!I)ARyx>wV-LbzeS6TZWvu>g&iq zLu!a$m#J8{1aD9FkK6TMIFbv*0Ot@0;~O6si0L`hG%^z9=B})*g-k&V-REJ!6G+g(e%ym)3AG3eBUMtca=>>|$`Y9! z82HfS1%@Zc@WCM=kUidA>;UtWi^2;`Y1%+Y9z@>^7I4{|09%XG_qd1V><6|-lSr)W zr^|~|sE|Ovy>rWo@6A3K{BIVZ7>^7Y5+ZpFfaI1J3bI%@THr`|*uYs~QMA zQBy;Nd4-?jIbm0&%_}q|5n#%YAOSmx#(Q5CB4BM?(6qh0s1OxZ)!hk`YU$i;OE_Q% zL8Ej_n88Djrp_m}9Zg;$G?nyT?^KzRd!RSBE&TrR10uH&>d0Z{;N;|FXNLp_c--$Fi*J4NApUthaP#rN+@`@vglYjAiD>>9`VPZbr{+7BSQ=_5q#dvUzV z20=JBHUg}+lamICY>+vnm7`%n1RD=@&+i<~H+!cUzkyU*{aZA+)6y|e2p|KO@R(5I zV|~4VpkS&IU>JB4a+q+N1-ZG8=wsp7D-ar0QqYwBzN?muy12T|H+j0++kdF5)6P@7 zLlQbUIVp?!N8|7%z&ZgtAILu>{xn%NfnTSkJuVn#o+SKFFuQho>tDO2UN|c^YMC1! z58)VFT0(Viy?J>!LnuxRDFUyfSyxEz-66qkFeMCm24Q6|&?sIolX~mwf-5)Gm&ON| zMngJ#D&Mk(w*vwfy?_4-^44-{;{eX9Rj{XdW<22Kg`x{5dSAUlOV}7&-HVKh0;FZ& zp-Xx&3dxRH+eCRY9!f<{7Z z=whI&D*!2bdw%eaAY8V%`}gNh`fgEBYF#!`w%+uagFjbU=|8lHjsPBf{`|S<)j_L7^V@*< zAK24<$`aRnoYS>!28V{#*cGbiImM(U;!rAFrlI1FIZ{P&@h#1ic;DQR&pWok?$-`S z*iicyV^WHVy?t`5Nu*eei;gVWs)Hz8e(Ey*;i3px60!2U1n$aj>?~t)we&L+R$mHp zd!=!4#dzo<*clZB*Sl_D;?hW_JK8wBV&+VaW6}E5l0YgY8}pUg_NN-Iu5NGdgj!Jq zKTq0PnYs$bHKpG>Gk7nBgl3D9gz}b0;8<(6(3o-$+nfcDmZs*G(RLT}U@ZC6=S1D7 z(rhsb=?v_@0Sa(OHX z4M}-rt$>o$8jZobqfqcQF|T(5hgDizBUTzh)dPodvBvj=uT~K{Y&`K-CJvKFaG1@- zZS?TxF%V=(tg!LbI~>Hm^3|XAuTQDa5ipPRdx}g5=#u?a&^bKX$#z7V6!~ffYR#bT zL`YQzXrm$T24xFJxkyR0k|$1cx#H>364C`^tXyMwnZvAXZDb3E;??i_`2LtZ5`iVK3Aww3}Nm@NB|ttT>k9<&;NB!p%^WA{ikqbeL3eU)J%vf&!2C*YL^?H zcPaZ+c2usavT$> za3lj71o6?%dqcl488(xV@A(0d$lo7u?l1*HyC@R1qW!e_vz72mQBPOO8EF^lh?DfW zfeCUFgwNjSYsOD{wITK)3*N4(iOh!3Er;$*H=zeEcUYL}Mawo;9(&ho%H^^0acGnH z^382&q(`jIO+MdDiRGD+Ze$LnZZjGBg0&k7^2@o=A|5bU|GJm5@WDuO^L}2RT>j}mcVFKuq(y3_{R-*f+--1G8@!{na5KXI+= zdarC+S30kgJ>vlVP;o;yw=>YsAyW=6j)EMLPKZ_BH53pKIB&T{KKKYxp~1XGfHozq>FoI_if^2OYmH5NpIH&~q5FQ%lvt)ALzP zO@G0GPaqZ*ZUU07g|)SecrA>iv#)si0>yF=~`I7^Zy>q3#K1&(`|-yo(SnIuS+yc*Tf6 z0w91$ya7)0@bGZ3eZfq3aCkTrl>@#w$Y{wGw%KY~S~4QdprCGV9~qO31^*!^+GI_QKGxFm{Y*zEVA@DhNsiT(C~g%+7-QOEgC} z()bA-T@Dn$T<}!#UV`}(oX@cKh|($Y0;7w$3P)^%lsSnz2tgD9HI(13`9nXN-Qh3P8c@6DjsNP28j}edA(_{??K`zA zXqDG8SS5DO;n~)b!uFv_{PQfD>Ap$a zI+jx`gScs&s<=;@!x` zV5uSHn>|9q>aX--PwvI7*f@HwFn-1xIM?Uq{PFgVj(FzWb41h{!j=~&EfCvR(ytvK z`!;1e`?Sdyo9z5mlHcBZ+iG5|>vWlS_Lee|Gfi*>b_1e#${l$GzKp+7=`z zm{@Y_{7OiN6(@0P)VsEMd){Q)7)tz*TUQgGEBxpZRrpd=lm;d72j)m+6_u5xFKoii z+F&~O~45c_U>t@8>SFy7EL6?#?quYQ2U{yqwBzA(bFUO zA<5(<9CrMwN=i!V>WTC6Ylo*OUj^(j3SNl8Rt(H)fGFr)rt((n;ER zPnxu>EU1Q^2J9aSs9ml*r+O_#JqOzgEUi<>zZ#UhR94P9!gO2DN&ul2j9Gj5I9Bu! z$%Q%_3}Sk063=6Xb`9*46@FP9{P?kR9N4CxBAYX02Dt_fT16Nql9E7!$Qa56*X%3{ zm}rsdAHBivXaD=Bh;K{eI8mhdU06w;qbls__f=@CLb0JtswzMxx0N8`YsT-m*&SJJ zI#H+%nmlAn!6r@$d#fODj+~#Lk1Y&&HqEWuuPiUe#m5&`%7og2BUdFKf`)Pz^COXw zV;0&P8hd9K|1eU@rjnV!g8uf%U+XaGDT4Cg+v5R*Pvp@v6ZOi%!hT?K%x*uWcxv|& zelfn4@WMd|go3eQ{{(hAy`_99EHJh?ls+ldMgV9vAFYusZ`ecISA6!H7FaiAP6KF*+t@@ing* z7x?&y!)W2vPs9lD@c~~^HM4bmmN}FQJ1@h9HS}F)0rx!=2te$zLD zo;XY#PAZPX16PmIuP}Ga);Xy{3e&w2eilXxjgM)oCotwiMI)l?KDs~nB4u-W;dJrD zFHori7YZmK@O4cYK9a{fV!#<$qGr`c>H~AkKcSMrvf3UaDbFN<0r+4e=rKpA$0sL~ zCiIFuCxFigM^htJT&v}# z_@&!{wHV~7hu>-zqD6)37O4H~$MSCoDSZz%e*Gd&s%W0&?v0!WyO(*0=-Wt%JUu?g zT+RHkEEk@lnmDmwf+OO7teekGGZSpCvx~dp&1hszk9zL`SkTPQIdM- z^4cFdM~eupqmKLS4zi|^4U86 zny8Qj7Tw0-XS&0k*w*`T@wkAaqMnX5-HfL3aNJxP44Ys2fB;^^vEWg6^I%v>Nzm8e zJ<~K7j+$knv(2e8g6}V0*n6Cs%^LIw7JJRn1x8fQrgo7-`B6UY#PC(P$KL@gndh<( zb;B#UbuNW3yBhpRjMRky>jD6f;nOFVQ)w4d+Dk^W2%)qcl#dcLgT~!38>?yiyQNGc zOY*&KM&+)92el;S$3_{6qy9mn3q=4J0Z{?AsHu)TvM-X-i6x@T7feNngx8ODTXuw9 zE`L7mNic7^dK?o$_P76-+iYUT^Uv8T)~oI@zsm)V2`Y^#l`|9bR$@NQ?*yANn;puTzwle9X4 zIJ$w$a`9ASf1>I`n2-$`8kuaA#HXC&)6_z3l3Ieq$aVieUSy6Z$a_eMmPJ!1qS{|C z^BruZsIH4GzG3pF!ZCkDZz=jkrBv#M%6GEgaJK;sX`FVWU-x6ayVQGXKNFYo1o7!w z-j%P#@Zu+xf6*|ys-K8W^IGi@2Zng+TTz7KZ?!zVVHSjK(Dk-A6f_+)Mx(_U&F_Mf z@+CNSZy5VUO4TcWQcL#B;GmSqHz-{Xs~{3A0DVF|Uc2bySO9IKx-71|e_UQMM}yMv zp}K)n&pi29j8?B%Wy$os-PkAxq?ko4XrzZyggld{O{dcP_trVI1>0e#%*Hf z?hITU&%c*l42-?t%>@w0CqJ6|t!$GXj^MU;Kdu0Pt|E=Day(W02d?F*4WtPg%T7(@7{a{ zE7!5)+U3&Wdja|ZuP_FrXN;*f$|0y50C^ixDg_MzkWeWf+=dm;T=-DBLP=h}Xejg3Ona?JZVI$h*Lv`!NZ8?!CSSZJ^R8nChS+}C;Xem@3_ zpiSWkG2Kpz&OMh=tt^Abods;-aDrVO&Liyom6^dV4yP`ITFk3uyD!a#-m9V0npdcN zS1P$ibg~!Jtp~lE>g7QXhl_v2tGoulq{=rTCW5d>MSY_{JP#3JK#$Qh@ zqGyGF&k;SPq!>3uOJ|rVqvD`q$4~?z(_9DFS4QN`k=+o5qusDC*}R(t-B`%rTXk%DEitnbdmfs z6{hfTESk;~j=&O{7X#e0wG$0|kJ4&vAKy09ru=@*cgqZRw<#$p5VNOjsu^3i2a6l5 zV0!CYz6$~A`LL!qc50`6iX?b7)#nduh~8y)S`WI+s}JOdi$vT78uso6-3r6K{p>vveEpe-A`MtK>px2qN}EF_5|Rw@#H&hvc8z&ns7y z&j0%l(2A{Z`KbKgfAwa(KjNXy5KtQBhoCjzr1GOEH(va}-ue6fzaLzW-xu|F++Ma8 zd9(Z%SDJ64`?oPDeE-}Kl0k_1ywYOl-`9G}B{GQq);gHLZ#{9UfaL}niCL$RWrYSj z;D*ci)X7P~H2=Sk)RVz3zp9s{_^$QO(rCuaUF!GEzJb`KONZg%aw;mZ*2+Ym8NSle zW;Y0tpSS<}l*!cU0v0qhJ!6bE`^pMTtO)-*`*wD3aKVkN8Gmu>?I98o)M>Dtm6e!S zvvSI(kUsf2`dx;as^()3J_WT%Q_i&;BF+DOqep}RTu`J9wY!7*#9O@vJslXSx@u~B z&O`XLUvBDM(6Iig%($B0Q6DEJl3Pp>VVo$I;5M}UV;U$YaC>XLp`rU%gsMfPQIu=QYBkpBB~f&3mLj6mujXHrkSf1ZwJ#fbX; zb@zc9=Z88_ShS%fx$;)?^x#`Y-I^^ZaQ=6wQ2(8Dq%Z?N@PD_;2MCS-@2CC$`7uQ| z(X4dCo_jToiJ|E!C59(Y|DyWeb1Uz_Q;TKbJ|7b2|6UJ^nFSZ!+5tGZ2ud2sVHJWa zZ@0vvX{sgueUsmV3u^JRzc>KVi9{IIqh8dlg|JNw1vQ655s057PhVLJtZYAQjcJ9X zv#E8|SDNV_4%hgPYgD+iH(Ka8`4 z^@%gZ`M9QywKWudkNI;(%(;`QYH{_FsIVO%R|vPCLRHE`jBDm_MD&!g-~4+g+V?Y$ zmRbLXe-0GIgZ9t=z8=~Jy-SGg^C8B6j@VPgRpo#0e;bzII#*(Xb_aDG{Fe7s!~6Gs zy%%4mF3|tEbI|+ge|Mg#6`bK+FMzkca*LAT-$yoIP#s?0_~(%yz3Kn=$e&l5tsGEC z?F$`Biu>nAhLJQh!Vo1P5iR<6pp>c07Lg7p+S+hrx|+2NJVM>)-G9&aE}2(IcNP$W zw5dJcdM}Waz|-$H2&vxI)~njNLuL>^#YQ^Tc^DWV<6%8GWcDR5Pk|whL>q~QRy`;3 z<=>wK?f1{A(2rwKxBc^n*P~}&;I0Pvs}GNlNBm(F=QLa1b%t;Q_%)DV5-X~%CPaOU z$SZ#N|9)UJECxa^7=4Z(BA`<8?{~0*^A+vw8OHxT<{IjVHhWxaRsQ!OZv(sdZWSOX zy`NGzNb-`d@B+1TI02*W&xNO7B;c4>z}jU`Nhu2VCaVYx0BQsFwpCE6X6xM0G=Xz` zg=|l4?Ahh+zi(On{Md%#M$_TA=0^1s^>~}_)wbX*8+{lhAsS6{s)&8vAL)%9+_eZh zZk8q9a9S-n?QS*zdNB3U@+gC{Znr8Mpg&&HjP}-T)t;}9!vU0lL>(BDr zcSTIV8V@QPk&1l&80d5iaTdPlKle|7{qu3i3~KJ&W(0a85X6K+(rAqIO<3NDXM(*6 z^X}cOK8vqxp)Gpdy}ctFhvsxrlF^WY24QT>0tyzow!L6b9WpDgum5gQQMfana^U-) z7XW*^AehLy$5Xpel$Sw`I0!LkS_{JIy@%Y zB|fD3j8bqIoUHWqDY+gLLy~jKN>hytjMWvBMG`cTPL9*5<&T?v~-VT?vFSmPJfjLV1=qNi6emT z{XlpYvfgvfi(wdylU9{Y=ecK-&Rc&|g6w9w*(l{ z(Z|Q9$bFd+qBec+5goy0K2 zXd*yszQ+3O06kT^sCoHK13>Q5ye^+0a{(wtNB2Bu#HHZ}6#k)tg{q{X62t%--T$dA z?0Jw$UZYj$dkYJa8BQKmK(zrOHaKMzFdd3{T642m06GI9%`fuqY_DTvWOql)U=Be; zdT#xBQ9>Y6cWrngE*Qg(ysO~42&5-~=hO4Blb?$=k10`5a&eSLik z2|==mTUWgyBkgsP;57w42=Lv?fc`SHtUL4!uP{U4%DQNTaPK=jWOoUThii|nI(_if zvh*9*X5c#4!Ns=AD{H>Zbxqb1KSMDxeDUetR{-w!Hyjd^|MYexvd zdJm~i{G_xfDK#>ZQ4$T0UTcKQ}-h8wrB&PUG1&^t6}*00lUHV322KS`{aJ z@E`?5XV~8zGbhBwc^ucxX%`TX3D%Q*YHYmZpF4s10eAKMOoHREhNnzqcllJJ`MJj}-3cd;7Y7mz-8@)&c*^5GW}x;Pl=NdSsB zNb69FqLWSLbv=Mf`C@^fsd{KxuB<02b1SBZfimhh1D*p{c$S~xqgQ1{Tq}&Ylr6dK z9BP?H*D-T09-VmkL`8R*=cM_G8Zh1%L5QPc+HMvU;P@G^t4EhK8*XDouo)4F86<|6 zHB$$)3{8HWyV*tM-BfW$9hSLYe>V(iKQ}C;nR6xRg?qPf;%G9>LiIv%tGskh^@#$k zR?U7+S{+XJIQM%1zcRS)}LzsoT|;#N>l;O^r_O2E}1 zo{6gOy>|~`>fLqDro%VC1(-$gOoh-A`e1Vfl2hX>7M$z)Gh|CN5nW|KU2;i*w+s@E zW`&gCgWo07r}nwmFNE9J7z3)}$qQ**zdeqql9&}^sI42j{7@?^K$ftK$8s@qK5?bl zci!_nC@WQ&o7GGV0YQC;ClQTP zrAZ@4NPp|dxSiOcM#*-V^>v)#;efuMJTJq=O1E_txfXuJRSXQ)%ff-TebXQiaD?pfbQ+VU*yz({o2yXMQdwy*R7j_|4udWi^%)^yVR355)CMNvH zhyQZ@A*fv!%+p3cBe3l~1m5VgnX%ilZnJ->3@I>n>hyH_@r+vKZkAM7@YKvlg^(|b z-faqxSSX8h$x_3TYw_Q;FSX=#e<0B+!6dA(QcC>Ho5UxdU&w?l#J6g9rt;_XS8r^c zV=fEE^;@(yaen%dih98XNV_bld?RdRdhxl7uAT*DrTZOx>SL=wSxUB&>hFStWbeTt z<<}yCOPHj-_CW%BuX;Dby5ifnQ1B|{ z=1OcK-3T-_)zknv1UQ>gCOcAKvV(KXd}RbwPyn!C*mx}j(M3~!27nk8!?=^6{-#uf z*mXV!Fm;ccC!bcx3K>X^hU$)1R!nV!4p2`T9sM+9wtf2eRW9UMA)SP#{}bTC$qq=Y;y)#Kz7Vv*pqCCg9soxa|)^BZw|2ow;x) zF(0}V%)SDXpbnV^-t}~_Xz=hHgQcxR?Mr>Vdxu{t0G|;4$liq2rKQd9HIFOMdLg+F zH(2pN_-*+va5P51yT@t@d#?|rO$3vvbOUa8HI`Rby*4!Gq#9$5Pr%S^ZdewBpd>0( zjHUt5jfq4Fl(AAlju_P(Umj@LMy)3rP_acx@NMjc_>Eb>grD*+1_b3h85I_zih}fx zPOZJUt?dxhl_pET@gc}u3`K7!3m5`h_FzQrdFBwnWbR^Mz zPq;E=Y3>2{0hU2|51-hre1o(mM4Ee9o7{n22N+!_BfB;{aXUhBOFRESB-GWy)4piJ z4O(O{suupaxCF%$4%yMsk%RO{AvnNbaRE~y%rWpmf#*BqtE8sp{QGAMUeQPJ%ER*E zjOfTT}o}1hex`qG@XSgF5&ULXbd%%I-Hgu-qK)c9EAcox!)?7XaGH z37m$o^DCP=+L?p!S(CsK4#rKe1$gW(hy}JO{A08ngzg0NBzTKW6tygptJOUKS$h#d#EVa7B*2~EdDnpL7t!se z;^$3eQvTTJ0a-yOW=yF%+zNWXf^A3)ten~lBLbf(JRywIu#DyWffFX#>lD>)SDpSc zZz7V&bhMxgA*CXuC+i7Nyw2fHRyo z>7y0iv9=h}XvJNv@t~UTo~Lb5W0oR(RzHrhD~S|!kGWTM((@%eSy=Bc?9s*V80ax@ zu;IC!JPns>w+;Hnn&oBeMNcM)-72*dW|%skD1oFR4KYsaz8=;uMo_=|EX6f5;Gv&1 zbGr|Y`H3(3nSoV5-av1SaI~;QG}Iz*-nLj(j*B4-1)e;a0?Lb$ z(tX5rj?IVoAqC?-%Pot?XW&o$`I^^A-9$Yeo2s{T_7xBgzyt*9Ih1uh)fqFK(4l+- z+aXt1jc6LzTVejBLo_|$+5<9F@-5V8FI zRTYe0+8j_t78afZQV>QQz_@cpYU;^?Gt?9nse1Po77S?asar=g0wd92;F?D9@(qcH zh^v(C8;2}^8wZ~qZJ_C$Y3~2fe<53NJpDzVkajD>)2YFnMiN*viF7wIBh4}i< zvP*%x(9_$ic?tenFq6qbv6{8*fIr6A#mZ_k5P~V^=k*i_fdN2Ef!j|`_~HPtmKtC? zElnz%*%ErQ^o~CGpR5rOr|n1+Bxqj2ZW++fkSlS6(GA651^=z=zz&2#V7~^G))Qdc zrI2d!u(7@V^LrkYUx;>~EMy$Obk(hX{SPDvh{j;nSo;9G6~GaL;0r!oW-~5cUgtl* z^?h4@@Aa|{^!24Yv7^Ng87l0Cg)Nw|2^Uc?yllqaQUPAxKd5q}Y>ra8{A*G;kWNtP z1G8BXkgyraa_fqrJy8!`T5Mcg-pkM6736`f{OXA)6+0)VJ?t%1bzF?WO5I~4Ocs#_ zP`-qw^Q?&tyz9_#;54Z5bgak|RCd_O4x|bmyaVe5iapcKZd6ro~)+HU2ug*MEqT%6DA@f11^lf004f{U{Ks25cWD8QzefU z(!$~ftddtcWKXw!RQw`gS)_ZuKmgoNrru#ao`OhLfaHO{J~m#?+U~04Oby+Hg(aiQ zH{17^AL~XqU~ab8Ew&N(f}hALXpl?)@{iB(Sk&*`{LA#b?&b8~g|$gIj&a=rvxU}7 zN|M*_&QQ-jOW&+k$CIoc>}4( zZhzeEY|XQ^WYZrhf2OY}iM90IVz?R+Qxdo8NaCo%fNcH6?r|VS}v;_iY5(q2+<*nTVQ=Xs#iip&MFIF1G z{p86b$$M>V0LKROE}Yc`eLc2XfOAt;#-nVXe+?c$08W7?5(E@5fE&?dfO8LmN-FtW zCsq-dZ)>}*Ip}^QXa^$eKdzdngoNj-zaNm1fm^RkqW~l}C|n3MSU3x$Lffh9zAeuJ zF(0P&4_^~=D-6n*{P9Ef83Pa)eSLl5ngOB3K{`M-$L(O~IV2{nxg&san9>h1CCGVf zC4P%FR+o%cacgdJUj}1O=Ft1TudtbZ@ZiCa+34af7%cnx`ynP$4$l4`c_mC|%dB2c z4z&pl!xuRsrlbqip@+`o~`4s{lIO1{wHme62mU< zYqvnR0z(vxd63TZ0wG~}Wd)*mK$$57BQAc(T}(`!>euNI)Bs}oVEVH>l0@hf;4A=Z zl)TTM4b&ZtjF$fOqFzi5cphB<=`6he5~^&d;lx-Sn6+LI!USc=Y5RK+yDZrZL^n{h zY0wSf2*SObiP9jA<;bS+d0h&$vCW$ev`uIVBzYD z;+BCp#!2=ZQR8`W>h<^7%dO!DSj@mO3!0BTACT2OFZa`WOOs{(7Rl2eT%PS{GR4Er z8}Oj60;%AN#Q<10h??N%krH2k?7i2YHgaG&Ln^1%-O=#~)~s+MqWZg>I(Z#CZA?yJ zSP3EDjjVr~pFTgYOMdljakn0v4RG>6B@yO3I2od8v>|i|hSxwWa)^z1fT$GRR}mKe zaMo$Ylu#?OZAd`qV6zZ+iX;D{BAb153scpiTCgq=<~^Ur^LHiLYTnpihiv6GNBe6Q zC%W41-x*r@npEHzTI^%`lFAp`qTr>unB!2=?;Ba>&ofGYj~8%g>vN@k6yPHn`#h!f zfWHKrBpF>5AA#XZbThx*NR>kRIN`fCu}E{`aVy6=4L{_wL&bXa zLp$mUuJsuG7D!m%FeLhd`y}q2GL_xa%m}*QC*I9r>k<}&!K1>^FFD@CCGjD36-!6nS162`KJO}Ly)jI)JQ={odOuf+V z!UMWrvsk1OJoWd9dR98q{188?vyUuYaTh7#!u8j@OUm>S$Esa4bY`+8@S98cO zl`B_4*R1;bdZ!;QIY6C&W)M!X(hVO!zP;)D!+l>SqDu+lUVj!N=>Gl!D*f5a9QgBK zhXShyf7!uGSWIcGfSPf1c$np|sL{Xt=7FHAy$geV<`7f}`bWDBIf+5F+)2A!PaDpr zqeOMO1Mb$&1I(`W_Up{1^tgP`OrmW*H#ZlBxKUD6M9yuz=0#oOy0!L?lH_C%kLc;0 zfrVq{_6}Q7L*q0i?cB)3WCcd00l{NZq(uBOYAkZBOX)le4hdoJgi99Kp{OMpxJ>{E ze&o;UPMI|8y=Cs{*_e|ZH&2)q2@8t!04@klm4h~$}ljGy~n1QU?wWBjL zNu*VE={Q+olYeOKC+?|zR|aw69u5*C1P4jWd}fp+^da*8y^|)dx>k9gWrRqGj(HEO zdZ<aa)|R`8YeD zF^wV?BQ-uPCY5UF<{J@(e!323a|ZP5)4RoXfwlEd)9d0ojCD z*RQH?FX+Ez&z`4&+%9I|`mSo;OE|>9$LTfefCI=&3K|ZeA)yGZBAQ{$aPy+7H zv8#$c6p$*=V{<0J8X~8$pMI6l%f-j10h>+1=TcYE$Zq!T&~7h;+7~ff@tL~sJ#1sk z@58Rj*rZmmUj_ksGy*)jA1)e8MT%bcToIZg1c(F)3ALZV{sHK{0t}--{Ffu656}l4rO%`cfWGwfP}-9a`T~{`D9gaAHF5Tp+)R;F)^wAXS1F)Q zg!a0r;hL&*J^TeJ=KAo*Xn_2H+}p0eU$6nwXxczelCd#mRn;`0JOLub@@4_dDJW%m z8Lu1;9*Fl`W^qYLXA2>$w2$qED58L3x+E|c0kycA<|jh=kN5O6mx?Tlvb>M zvp`hgJG}qiZL-d4fG``B027$zaiuJ54eA`j$YzY*To2F(@d(0y**Q58WP&JjhyTE8 zQd(NhUpl`(?v2pk1n9RBs4T@kWY}l5S@8S-!2wOMu#t9MiU0u3*S!Ff3Dm`>{XBOz zNVKPyq}2ohN5nmY(3uz}|CCBk=2IeEkkQhSl3gX=3X~Un--EwP?dGlchiK8p z96RB47t!=TqhtNXIVdN@ZF+ev+4o$sljf)?b4{5-*$cZ>MVNpVun0JH8uGIS>#o zPTz|%wDm}IHm0dAj&;!)GlNZv@=q7-SZr-p30Gm#=_+<

#rQcgSCgIE|eqOU?&G;2c+6~o;*o3o-bqPAyY?6xTLFWDMAk%{)G0LSL4_VhG7f5D4r;6%JeuMeF0zik0DDe8Pd#;_h zV!MpGCsNbz=lC-wHeiwlHu~RP^1r)~eghr#BV8nDV57~aH^qF$ZOb&HD+F0$N=|ac zJE;7=?%@(r>q=-G(8_=1{ZawpnlWW`+O5q#z4q5t%s~&AHezt9bffPZ>G5%F1XKG> z|5Co{Wyz&3{`X?a+V0zf$vl~x1@4sQ&a}~{3kBlgtKh@lGjq}##wlRLTj%X!N+Cz2 zkXPe@6omfpmfd9C1M7yLVPTU*evvb*+9Y>nzI73XTB)>rG*M0d^(x7ZMRKS*6kCyj zFvfAYGFYQ0H_qw4Wq8MOV^3}`tK>bRRGT9?74T43+oT( zz4G)3`SE96M|eO#foF(^D5%3IEZVxTnmN6tqU3i)iDrt3Gp5&HOOhu&L@)V*)$afykDh)78C!77)Rr@p1wTWY1G z6vNaDrAVNj!wmvu1u)I#=l5Rz`lsh-0*xdLilkjbEkVH927`Z|c7u_6YEn{2*CM9} zpe-^kTB=O9KNUa(0FW6lFD^mkfD8g_ULqnQ@RXx;!Ba%yVs-iL1W|={!W@|mQByM5 zesIp7L45{TyDOEJ_ z7~?)vb1<@SkWjrBAcHHbx6b^xK+(uLA*zd%q>1&MJZBb?=HBkv#gtRfz*|QN=$M6z ze$ktL&xm;_-@SYfMwCE>2;g3_H&Pk!}8c=!WW@I4y0)A%L&*Gm{c~CoItUX zou$`OrQ!QZ`{ieqaclb2L4Kq`Qo&3a-nWCUQ*$6CmUb;Raz?!e2XT>3I=7-nsq&TP zrb85Z4cs@JSzH%$&dfIPdN>|+&o@<1n6dxP>#868wwm-RB})Bf8Ai6Te9$*1-ER*{ zx(YFFZH4TL*(`(&e21k|&UA$<*IA%VXzVJZebuod>HBa|sy7-BYvf9p@?SE5I0fx+ z@X*Cp-hNikhX_>Nc6_YN4y96&emHtw2VzsBoR} zR6-M=sdT^2E_mCBzO&)Se*W06@+dQ%A}46X(Pto*slcb;Pl3=Vh^LMDdd>o2APDu0 zK6DzL;sr3{Fj#L;c_xb9y&(BVY6lxb#c%$OswKTTNp$&)9-!d_7R0BIBw zJD|IcK!9v)ecbD`(jm_|N!`r`Op}P7jcTV*+rs!Z!W(U+(yvI!2<5t4=;w*tA2ahI z1tKMn&8CCNuM-J+v04u>4MWfM^Uk2!fGdZ6V}Qx^ zzN0DAK)b1Kj)vvv*4W08b4{syl<5)t(#OYBSyjTAIUyrwIMf4-D;eiY8b~gr6iN z>y5d8%PdQI8*YY_D+MO$vX>F)WOi5-nTv~d+I|0+xnrKY+D63 zC2C}QCEmCVV+Mq@6?7E6rLf0qth;Lv3Z-u_bZ=+9geYbqFsLUfrpkMGC}X3TI`=$g zH~hJ+-`6trP88`5oa3vXtq|&_yD)3X?rFj@?Ahz5shuho{G0Lg3=~a6(tNp!A3N z(S@5R`QD(X4oNL_gH8ndxEUvF)A?R7j$-Jpw&{+JE?vwy}#d z;28luv0_oqFGxv)i(XKDm6C=OP9QM_>C3`(fkBusAzubJSh~Q3DhO3RXwQukB#QZf z&wk=|nF>=gext`T9=8m-W$`?VqkB>}%6S`* z7ljAZdy7J4okmViT`Q7$X%zvMo+H=S)4)ho{_c3*Iexdx7S0t2Qj2HW_Bmry0#X6G zf>?5XfJOG;Cv;HGU?cK-z74{dtNR5+iN{QIA)Tm${II?jkatwK(Lfm!SLrPZRpP0{*fc%-YAV`m!(i#4F;D8)i4l z2)7I?1bE|&zgye~3?PK*G49@lItRdMvM;zuF?B)n1GNxD%ug>4TpQ?v|H+hQcd(iT z$RfK}zSS^FqSietVp<`o5&~>)e$Zst{&`eU@XO=@VPa6Z zslu)8Sdl)*@y|jHYy3ReUl^4?w^tcAz})W zIIwo()!(pndOy`OG7`o1gJZh1c;$4HGRamMK?xK2p({j^H1_)PRf|BRf;}>x+)ub* zVWkSI0>?V-oDu82kvi)T{r5mV@QakqdEsa(EA%rx^ZqzUAs9D(CpIc(wqR1?T!C~c zbPIhb3EQK*k<_<%&&3Xdc0W!vMmLhmPK-~bY)U-+2 z!w^vfERWT@9zaf>9v33NzoBf0I`=En*V_|%Do4kI+m|<2)%S1kRJu)H`v^|Df2^;K z!0-gd(IE|^ahX-u z8r14lTK{WCJ$;E+`muZu2hscbg8T=*wuRnz0=hHk!r9I9 z6}jVQiOU!6NIE|MVngNb+3c0u8_C2rtF19if3os|V>S5hcOmNloj&zY(mF*`QY3!<0cM8)4lj`Dppaa73)ViyoS*~X2)>jS!D9J+m*vCuRN&z*bwQ{24VvUOWjayqT5?}s(Z7Uy<4~6X+OA2OH)&W%)?16safl2LGzVt3|g!i)n;P6Mv1_2IoY$OEpUcN`yq-n z?&EWHI-|Ss9L{DKox!^TzYJ{u*pm9-@_GF_5*r!lopYh1IOLz1nOVxT3Q>7jg6fP2 zlF315<=*AJ6Zf9v&pkY)Om*Nq0rLtL78dxAJ{|VyHYRLshJQ_TF^)He-ilz^gq ziV~XZ@KQQO368S5`x_qtQUDM%K{hskIzc?$SRGJMaCO-$u8-2l_(hhFs?O*h-4Zy{ z5Z$>0(6ZKfKrF&Tj0v;h!YwXlQ+HUmFqZ)aj9I)|Ep^dEQl*bhQO;Nr_QL+j@a5P2 z`Y9_Np>d-QgZsz~=%b9U7SEL^y)jAo0IgVuJ7a%z*4pmDb`^yBqW_u_(mp%7!Z*pV z4^eaJXUS&$JEI93WaSXRM?t<(^>R?ZZ&8*z?dTM!^x8tNAO~Jj6ygU2PS_W7#HgB^ zuS0PSJ|}!9Z5&;4>3G!K&RG4hUY-4uiZRR8txI%SJ#Fr3!mpqLw@R*E$z%r}NxS(#(_&!;!+S z4W3Z8hk-B6eJ`lbB=a%2`P0E#oOZrMH)Ufg1F{jS6{Xk3gqJqd|2}^FSXQP~G|{za z8_{J4QtVJ3Qzx z0c@hT##0#?R9T9TB%=pbH$WnTY9W;9pg3s0${)T6xOzD-Fs=`0!Ez2{<~S+5kPZQ| z0f6F}>vp(pW;jn+@Ai2a#wR^oQlUE9bZRhZYbZe;_}?Q4@OjR92A2;uk~E~C^KV+i zS$IT-G&$O3|F>-`)~QYCgC_VX@cPyCV=4qgh5>(_@Y2hJ1A)V0D(AKJ6ugeBWEAwu zH>eQ9w;xh^N1{n0)*Gyf9`BzS+G|JmnT%W$ngbA1&n9$Z;#YTsWcZKOWh0FHVz*)n z=Rz<4Ob2DIq3zZCD2*5izb{IPrZpyiBY#kvmZZ(U`ldLq?ICBFJ#FwnM}}`40O9!B`%7eV++y_XR01$w{?4wx0dXm8m<<(pqme6Jjpb^>Qv>)cjnmDS$gSg z0G+~v@-nzHGpjpKUxYjEm@^~awY7{2CGQ3qklz+EYTD@=h~4fR+Xx-?*3y0!;-UCi z^bxg0!mt89iT2#WWLkdymntehPV`qMK9SFybD@{vz8wdhZJ+DNAA);#m=1+L~r zJ*AOP)B(3~7SBl>9b+voDTEBD*tG^fjgfyF=?L5-YJlzKDz&ups+#8bZSZiu>nqIjs-Oc~j?ecAEWbEl^4?KC#j^g-UK8)wjT$*H4G z=v{_29P}WxG&cvg8_hL0N{asjCT<`M4rP~Sh@Ah+@&bsXE05JBSo|dL#n=VF1 zMndtnIPU7$DiRFgv(?8Fgg(X&R9}Lp6vqw?9gLMJa z@d@zo(&siLZr-@z9~kI#Zb#unCsm-5^XU^V0YRCQ5f1XYhy+bE?1jKgU`sgB;o|sf z+^oR`ujKCJ*kfB#(*b3WWQ>ex!WRo!-GVbP$pqGFSrdA}#%IiGZuWtNkvz{pp~czn0zDbPQ9jA?EG$k_b=Y@#;C9j zhE_BpK>t3iG+}YuAz}C;h;;D60023I(luy9qZV9AuTZ_ihllUs;MhXcf`5ZRkB}Si z0#NmbOT19G=@LrMo}QjJebZr`3N}w#+7)oPZ;Wdc>p*iW=y?XrvGnxx&5jl@)&uRn z-eq6Q^ih8j538_n6WBAirYiu9;#k)J@G9UNcI}26SAr@G2CvANn6`dTjdHW7hzKkZ z>^t~OX4clQ2G*^5rEO*P1%3n|?Wr~#!fq#0a}~Agv|SXhS{0$W$E;Odta<0REpPIC zTVUCfBn!B3xT!|o-IPt_SpM1K3kyG4Nl6&z^qX!#$zGiyE}TrTmSJXS(ac9c?-XU2VcLa^q&Z{EV@T~fJ@MD@T*kz zL1Skw=C=L&Y(5ASHK?r`OR7e+$xU_LAPo@GoSF&>VJS>(d9C7*>)B2bbo0x8CAP1@ z*;RC=&at`=Nwnt>!@YB_c zryQwF-!@*o#4tR!W7NGc-E~`J(7;S-Q}b&DANyc|zTWm9`&s!}1Q}X{4xEwWalRT} z+jGi*@y_*Ey4o6Nh|cl7!Y&I`SK2?oSXz*~A$_~6@r>UA#z}9+uY83Sx$^Sz7`^x% z?#HkL8vqjf$j>c`K8u5_4{&>A5RZ$!kqTEEkndxV(5Oh=vLZ7L9FK-eBqT-O(7b_p*@&sNKcg6Eylw23SwUbd$ z45M^|NFI`?U@A0fcwzvf^TwvhMB%8hINVrnnLg2(!8 zTo* zTb7MqEwIB2&djKM@flM;x_fO!u(4D1`Ou+xm396p0U2+GH^vJpw-=eAi0{7%IvIj= znajuvm}8`zyc{wVckiBSx4Ia52a7#{_4lIU;#Po?VvzACK!d=XnK9x??W80lOBA$h zMhC%!YyrL;u%9={&}NkJUm^U4E@Ox}KBdIET}54ecAnw=W`wk-_aVw-q0pg?3U3Pe zX?bInBCOdhcSuT-e^+YK$7@BR6oE~l^J4*#vuo?_8`2+r!pJ=TCacdj>aH7?Jke3| zkOHEBtzP&`&DB8ugJ0Ok24|Ya&9b^KE;f&wHOj9PHN4gkBwjMQo6btnqN(!Wy5gkf z3k`>K`mzYcyr8y*SN<>UpB^PKaH_kyjZDYAOa1$b!kv_4^IiI243M?ux%U!R$aBaaY|v zpe3+ ziu@v;-({PmtgpcnlRYN3lm^>+9ubibZOy}fYAcQTuUKR5NgKjKT5-tQ@m*_UAesx> zSsvD%W#7M6OLl5F?%n3%TwHvt#_?kjsmodPk1vZxPki^CbFfr6bEa~L3HGoI3^SKt zs>u3!q32r5@~ZO@d*6s~bE699Mf*F+P#ScOmdI1!&X1bK4W@gs8jEYpZyb8!}!Y-?QIKklp6_hbq zFay65+o%f-(QCIyuz2EEG>oIMAv0?$41NSO_2>6L6s^iafbP7`kIh4Jg z$C=SKJ3>%aOzgyoBC#x`sC2te&J`uQ2WAa{WZbeDI7mhH*{}bs7PaIQ5O6Bo)Y#G` z-utBM_Be?qP8>}Yk67@Ly1~L71`(>6z3;tZ6FH-Ao6|;d3=FARO4=OSm|tE#L`n=g zgb(yS+1c10Vux9~cuEC!)NMy{z2V@u!duumWGiwJTarT=!?buc#i*}Dxf}h$9cnqd zta)~^bZ5gXrI4RtO9OvKzSnuKtoX0M(XiMm9cZD9tETJC{IPZ-+_#5ha#7ij4Fe(tiekBqDOSUOpJZI z8uPeoli{k#cNej1v9b);DSHCk9+`HDoFB4UKhmpde93LV%qrKm!}e&X@Pp{)`V()h zg(F+&<8H=mqqXM_;{m?Hz1A~WQ%UJHAk3jNJ)A^D9?i4#a9prvOk5Q0<`EEh=V@_B zy*uBfL2aU~`O2HcEB7_I3AR_)O}-R%9RNt8`cwYIQpb3sgu`j7-~oD0+7Qo&ZtbJ4 z5@gXQ&rY7aSg)8FEN@=3PR|%!)c+8fO!D58t2Y>3c>GR9rN@>X+DC;vX`}WCodq9fPhRh!hAWK^}`D_FO$Z-PX@S}oZ;U6?@e5Eex=scJ>)qI!owBu1**P*Ga zD_J$OwW+Pc7krHUb&`B^j1T=)5+u)Py7k^-u=5Tr+kyQJr}ms*c=_S^8^-d+@NGn{ zSJBl);cX%sBqm05s+0Uax%ym=o+r6ZmGjidql4<*hI!>%*FqbDu`XJ>2;IOf;0Q(* z2d7e`XBb+YJn8#3PB7qC#$`A1aI$@I$4ce*sP3Y;CsS2CQ)fS!ocFDpE16RN!apWA zQRmv!uNtwULa)>A1G+YiA@Vby*G^O693_bulO@38utebS;Q??msFzh_v^{3DRp-+emfc`dS0|KdYCo5aPd@9b}- zKOU-Yj$~%uw{M=SlVd+AF#t5ove%w#Jmux)hGgmoTM;ZRWR^kFrD-=#n$XP1|h)zzyoRPlCBv#SOcDKGEt2JE_H zkBynYVdcYzSI%GEt7hc0(MABUm0E!{7;D+gj3v^$y&5XuT2$0K%$^;7)w3h?#VEPH-jFZg&xcJaR_HOVzpZ9F$md}{8>WaT%O zE$aQd-{!Ac6Z1(Pu}PkR2P)xJ$+}lc?C( z9qHOb)KbpEPqOt!r#gRUy2dOUXh{!bv^8%qm40!rLM|XEC^%DqY;l8?ho{w`z=V&c zM3$I@WNl@oP&S%l$`KM8r&W$ya@W2savz$k%6ZXRKkZIiD+`rt!t0Uj@6B&T%jU<) z8zq=J8h@BPbY#Bu2yQgfl_os+&wDTWarY|W@Azh$UnO1)RSdjuyA`uMnw?hO!Ls!Z zL}v9qrtPhbhqiUzKdvtgHJi2O`g8-&!>PZl1sjG%u7CMY73O+GwOXY6!0_RR(?Z|I zKS8dBh)BTIkLKf4<1O8hkQ^43{}ytp9_7aaR&~xk%+)s@LB_(Pu|C*N>O$W1!mXIg zSN`S+h~FC8PSm#T=)M180k$q=a9aKDQpEGldiRigsZtwypecxmNR%j5QkHdJwO&55 z_#(S^xI*$jX^*v-<=*L`Jubdgab~ec1`dZ1JC2Jz)RiYO(5Z@BePo z@9WEH&zRm4J61fIm^`>{+tR$HT@mK}~OG z2fJ4VP0i)lj!zX=KCJQ*)=)if?e1G1-kT;P`3L)$86W493=02Gg`{4!_PNgwxwmbz zy?c7EJ3RDUcrVpXh^Ijq5QfG5N*e}B4yU9W@v@!61!X#!& z4FY!>M{k{Z3}i2W3hcpKGAkyNZV9@bD(TCPu^pMj5(5jXf8d9HU+Ke8<9}bc{uBtz zU0q%NGO0k??z)1FE#R^Dlyv5gOoJ+?{!4;?qk&p!1(!>9CRngEBM*13OHzv|e#ebiE{X3c{gjH4< zIwqdlxt!TgMEGOLp7C|E^NNN|1a96oZIpZHJN%-jaw{oxHTu`~PHC z^3UF5!?pZa)D_ZcIeX<@N$kFyh~(3+{u2aoP2SpL-OZ^l}?fgjj$^TsnW1t!#9;*&t(4H#}6g#Y) z^2Cq|M#)<*hlkC*R>7LA>YclQxkYrU`S;^6=0< zeE;81y|?!3_aL!`f)1YI)BpF8LtERvv^SJBh!&)tj_?#Yc(=7v71pc1KhljP{h}>T+|unk8fNe8OsJ`0-iFflro1`mk&NDhY1&zp z#?F)yxj8umur7*hu}iJgq=t*9(A&oV70#td+Qs`U32U`pQ?BbykcBPX<#p}|L8<+% zdjjRQmer&wQ&jr%!6BzM-1cOJcO>G@kO3++Zwx^P{84W19S&tk-IVl$cSD8_JN%^y zpcZUUaT)64*U?Pm=ZD@Qj*^j)fWpvjn;EZl&<#srX$72)S;LJ@$yLtiDfGt38*y;K zulkzA)`D0I?oxA+d|!S1>;-*iw*eNNcB6SEGe-fJ>$^yZZV*7LBFlT z+^1ex(kI9%lFdH4Rzun%$iX5Q^?fs+wUH(Hwtd%m28Q!Wx4cX|z4j~d^M5fX4=Pn` zoh_4jqaYvUdU08hv}}@r~2f@YJ%`7)e8k=`?qzluFX$(2@^DzPBDLX z4u0)KyvKK6(XT)6SoBJuLcBToByz?}GV#Lr+Y69TsTlGaFx~tW3!y)!{^Id|o7D07 zkUo3mZM8!bk3%|9OCT)VXzKhGIo3@XLW8VBvY^DkK3=Qsr_E-87kk`t{xMW z-1zZ@W#Z=h3s(}tTid6KUhc9vYEqEGC;ClxMC?OZU$PPD_UPNU9&k65ieyCwpCTWHo&r*QKR= zck#*WUgF`!s|AZ!j8ZOsj-?wiwKJH=kFz(LiOuGFvu}Iv6XX2iq-rZ_vTI~HF=Blh zCPH3=CY)?D6bYG=Gi;}=SX}(0Zk_u=9;D*nxNE_KBH?IWk>bu@5DnY4pqihb)|0|; z)X}@HQ$uj)rH8zpy=u>133v#Y1{A+r8BDeh*x76SG@OTT^$2ILsCdLuN>@8GZ@@RN z69Z3&8lJDIFJ`elY40!-SL`Mai9dHy#5k*nKfYjwPdMk%TkU>ZGA*bhga*AiDq6Mo zml@qOHIx z_gVeSn{{nHM&@ZgRR=S0gIRHQ$7}X&`t59)2M+=0@5FXRtpNW`P&6@LK^WlUN2kQ9AkQV?5R8Nhgc(P~R+=`6(|AJ{GDYK6&r-vY} z00t4X5-@MUbO?>WL4+ls4N|_hAEKt_Gb(okwHjg9475HNJ7JhYyj@OFv3Uoj)V-7x z4YXtccFac~3N$Ceeuco9UY;Q!821sbuWl^%?!O)Egb7Fq$RE$2OFikLWoDMO|L_c! za^fs3^^4?aN3Ou*V|ii0sMfE{p*W`BQ3lryzbD_BKEZ@c0-eg?W3G{Nl|q(IdE1-p zf`So<7-gy_;>P(%T~twNy|I(bxq6B-UM+-bZ3+ELfIHxJpR94V`}g`BBJ=VH-s4Zx zA=4RysjGzLO1_UgDA$W5Ex&-P)AWRk^y9tVKfiwA=O6Vu?U1hbx9@jDv|RE{PyHuj zuR2Gv9#-8UKC8sXruZt%JM}}4$fw_&3Fp4h(CcbDm7gbs(BvUm4vO7y3o)wM~M;U|}MJFMG&P{zgal}k?E$a$C-YS?Kk zG%JzDc$k+il`r$~@2-eu_6hQ9GJ~PkM{R6gJ5u-!-2b?@uGDdGB#F8s_2CbvPjsoT znB&H;tP?Z$-j(9-;1a9}>c5u0*b@DT<>2bCp@ekOYfdFsTe`}f&JL6_+o?wG(@@!G zLw=xWsN^Ni#qldI`$U>-V2=|?JVfExniW9o0~h{7K^~$(O>-h`5)N!z&kAO6bo81kWNLX zqYkb;;O6}Va=_u$t$i5G6&Bvhw5u74pn??)4&ezdCGn(- zc-JoU9&skO43s+=mD%?^RtT3$$^bW2lwS{;dQJnP0IyMz;#mH*uz;3Nl|@q+#?t3G z_7`=|L#rIV2d?4_1bwU3g$ZA&l*?wAiJ{Lv41sTJ(Zrw|H_#adJORT3F&x?ctG}@M zB-x#DAmL9*_wMy%R8&=M+r{kv6WAmRS{hnfSSMi#J_qiWpk=lCEfGzLj)yiKNxrZ@ zqhyr$IyTnG`EGWkP8M^}h8sIEaHZU{#~vazTM%f1bNhJ{lRRjxUA?-?A%Tamh}VJ( z9z-)8vj~v!Zqn(sUmhH`e=-ws?a%KPs+X=0&N~iz6fO?iy9eCqB4l@39}IMo;$rb8v?-*#y_%_8{r99=x{-wm})|;%sQ%6=`t&dyNo9LJyPUXwDFT5I1sPMY+kor=U zXVM)OfqiT=>RPnzzbkeb%Jlb(uo}p_*Je*&H&=SyI7G6?A(8R#Kri37vCiRqCZ=Qc z0@RN=nd=M;v@NBzZ2M9j;)IOZ)}>{R(0n01OUfD_=jzOGM)MXUzO{B_j)8Agfta1% z3}uHn+1O(v%kOs)Q?;HbVgOnP#Wzrj2@UWwXM%-26sU+3nP0(X6G)A=hsPY8oJ^kk z%wu(7>;(%Qu@V0gWEZ+QdW1%Mh{mUW@ZG?A+Zs{F)^DQtsFf3%(lTfI_fJ z%FCC);0ce6gg6E;EveVm?3|r{AmB05(1;?>?rkhdYZwiM6NRtJYChf-53HE?n`60qzq{4CN^N z_Y3A;YeU{6Sv}K#|4hg_w&DHwO{#x`s|OGQa4fd+^4mkq0fWX5hD+_DwX+EJi{IK4 zW22(9_4QZa+STt^j+pN={UK}5%^Qt@;d`-!KTC)n)&28jh;C^jAU7+&ZHmpd$KExL z?_#^e%wh7DFG0RB>Pwzi5~3S-W@!zjggv|GdPCPrz2vmRV=8$+b`xzU^}G8Y>KGl@ zH8fm&ATMv6Rn|Y_5yxygJEr?m{7nP92uDbob@F=gBq^&bGliTPd5~?09BDGQ(5It{ zW2Yk`4jrhwAaUr|{Akz9S{sjpdp~}$SCx-@RMVNCoMiE&K`rvm!JAoC9jfw7w?||e zxFqgP6nxf(O|yd^JMmTlOf7xuTv@n;r{4Xc*`>fv04qgYp3gUzbjlcMYW{@n7$i%s z;RMALGEvki!J#q?QKAg2d+@jv!f9#|3g8Amgfs9b1QZeM>cE_IOoj5({@dw* zD?NJj87lyIg`fWU#M>LdvqVUf6HUC6IfV^}zA*?DvxtH??DK8zbWRYw93gKC-nk<~ zc6$0(a1x-J2)(3^FHPb7f-J#`kxY3#_Jl|X8C3AX%M53iCZXS_U*XIsM|DnW{uSVx zP2a5U(hACRQ)7jBV27zbc%2!WE^K#D)}o-5K%tb$eMSo7|)O4r_Wu?Fk7M|7zr4)*;0 z;X)h2nW22#SnzK`foW{d{!Ny1obIv*i?Z6Uaho3Byy``NsAWGpvC834obac8E==d@ zsxLj@)-ks1O3V`6Z4njDTxV9PY2Bel71P^4Lmf?b>d`uD;-oTj)YRBrm1@K$R(EC|^uXuZu)X?wwtd^udZ z^)}h~g7P1gy9Zg*qb}e4?APmVENbyJ>Xib`V{smp9#cM{0SDHDuhp zp zsOO=%Iy2)4zRS`n^Io@xKBgI9px{*6G3tQC0`6YDF1w^&?ABEf%S_M2q2 zvD(!#zmHv!oLOdY)D=hR)0%V9$!4H@F9C=nx&BIkkAovP&-)rQboeEAlanV&cxGST zWRDm_Um4HYO`HL%;aNd}x6^amuE+3nt>*u`3{WL(+d@gNQA*jX
*av~O1UpPn68;RUcv&W&|@h+LB{O@g(C zsSwYRB+U~ig*0y2ej7Y^xul~j_p$)FO87p{k0+ZSCD5!-sfKjNlFO#Nlc}hlUJp7( zdRkbq)iqYUx4>r6VOb<^Gi~bc@bQ4LM8lGphWrm^tf zf0ne&qpa{rbbG^$OqQoicWhnm=E(3v(H%FbP_&b>0;_*fvisg?GynZuRXzdyyhE6Q zqC6%ET@K_p&Rgs~C{tgu-Nc+0?5uLT?#y1(fU7`YAe+04!v+d0$mv9#K9wR{@?H40 zn|OQJV-*z@Kr-3b*+HT<9y1_sEYymF)<|gM#;zsgM0-hgpm(brl@~n@@LFHIGie&4 zO>3)4*ki>AhW`BY>N9MsI38iSxWlGr!5PBE{~>D?qu*Ew(88Dylu9EbdhDYT&;A`D z=8O|R+T(>Ma0)s&I)WRPKcF1t*t=36xPvqeY{(duoB0QsnVAn9fX2c5PoG)^)LNbZ zstd?uP7fXa_2R`JYxmw*S)3!ivK{m@Ts)OL_JCvWIUSu3V5rY{)r1;ujbkcbJWEfK zq4bhq>PVDump<})8v06Q~Ki`(=&SNWNiDk7gpD!bk<$%g2P(db@zVJlWsav z_~4nwfn;uyOUdu8li#xiyE4p+S4PGO%NY2d4|%)jK4q5jF71)2__bru4(|-Q^ z8L>f1yuFtp!x@O_aaXdCN;#KeJo9~yb3r`@cDg*>gMbO_2Lhre_?gIma&8AJe`2@s&W8V#ARtrDGGKc3Hwr z&3ZxXLic7Izp#QT`mWIInsy~sQOVJaGde7?Csg|Qo$L*(ML6%d6ut}n0(+mlJcZkmxdJ`l_*5nPx~$FD71uH1g(;?*2b6j+kGvj zc=Se|o{u`T{wv$|W#>I#A5{tl=~0Bjc8otZoYG!eFvJ{fvWPLwAe!s+x{>U1DGTXq zq9G?Hy>$i17Zw+jfpo?JK*Oj169NlKGFSZo3Y;kw7k%+WheudQXz(K$W-H$7{7}Wl zBcoO1jJ|tVIJ4wQx5(^U>uCdtm|duyAVaeG5jiVJLr7KZMCfEvM-a(TM}LKWVx^H+ zJw<3^&_DTy{?v$?!y8W zaom=DF3TC*3BC&fo)S;k>tmPe9dDMjzVCeMxvP_o#3k<5SBrr7Q6oKv4MC%Up=aUp z^29ln8D$UWt~s%v^wA-DLz39c85Ypq&3ybF(7Nh}wJD{|Xj|_bl66yEcU<;r=2Fef zP@bVMeBH=HpC~%xkn(_sLMMtet&eJ(ypl_%YL1hgV8Gz@T<7A7Nk-z3y=0>&92jj> z)f9P!ohr|8)bEq0p`U$pnB=ta91V@yZE7>p9?^ae&QHDLyEk0rugpAW)owI$&-oRy z*RV*;uF1l!HhV#aHrIPeftCj&Uzs5 z*dtBiiJ49dWpW?5G#<%23>|!2?SPKE7g0jV5CIBLvn0Vz= zLLm=L8bXH%Z8;g=U+MjhG^E6IcS%^O85sURiB4QqvWksVo)6nIpU{oxH%fV4G7sqy z#6@&Su63$T^WyRd>=#r4Xyg2Zwgmwo+`Wvl+G}ojDDjal@`E0G-JG7}W5$!z&~yfZ zwcN#f$u=QBC+9Tj>dH#x#)JE9<1;PZXulA!`Q-tk#?Lgl`yK)>Bki{}jHpbZzF;If z`0?kGWuT+uI8L>v#}Dt5kIl{%KC9gKNcU>a#(Gem!oi0PjP~;40Ufrhy|Q9O3scW0 z*hh)8ltSdz-ODVt3G%H}gy-)*8pU|lXpys3PS~z-jW0Q=t)b;9lDHnzO8bj7Pglp+ zMM%@S=H+i+xBYOEBB`e-B#MnGkjlR2d<#FV;ogW3O6HHk`0|ULW!@PYTaHt`P114? z*3_zMvdY*^Zs-3zY+{ACl5Xd+c1%X$Aayop|4*HqhWQ0$C%LNra~-2przIG=avtfz z;trbZiTCf%qim}knw>`o1F3;aNa)vHvEdcKL@FHnf1y{Z9w(4ilHa$=$f3NXyrP(c z{N7jQuV1Kz(ackNfSqnMt7iq4RP5`gJHCIZJor|E{h11)eH&9esKI zI#KV~Qn8rK^5cXxCywYm1N+F_(3F^IjR25=()3vh<^P1hB zgW1g)T^8t-asq8U*k$q$NsZJb8_BhXTp&N@DA$ zwETf4*@bGV(QLuo32q+38XuTfIAksde|@WPdibE2!og1-5#5LGS~M^plBW%MZ%K@H z*SJ~Wwe(x(XpiY!zj15#>M`3GCic(suF*GXe*Nrw#P7`PrHWV|$Ht#u%WpWt^(Bt|#+BuZfEU3mhzM^!pnK!dDz_EE> zRG2vZHM=f4?%Vg2A~)DG6}#08@)q?2yB2S$oa9nDow3KZ=YLoLmEaZ2g+IgnI_pc# ze#wDM`*=SlJe#8|r`9zHmhx{PUe)A()gho(z;*`MANQ)~P+BW4e(x!%-VNRmNGixm z#Uompn$Xy-8roR$dVw1JQrkmc{n5z~5x?J)sM>b!+C{l^s-6XSq~xA>lHTwi)~bu$ z0>(Dco=L)e=XNz!WDmH4qbC^B)%^otclHy|H z&i4QmLgA-B(Ycx+Y&Ao3g9XG=+RQdTZNJ)-OuK~4{$J<#;OdZ=Xr!%uY2AM37LAn5xp-P0AJ_6FO~5rv}V2Z$NX|z>gKaAuYIrT5Y4UD zh-ht-f(c%8)ukj1mh?xgO_w}p@mVRyM0aMtZd%ao3Jgf{&q?mTR@%N8FSDW5w3cff zeV5hMQ9FBOGya7T>-_M;HH9|SLv`MZK7p(+Q~$mk+MG9+=9`refVrA(X3Qy;I0?^b zaH-I~1&}pqATc2U^e`M$Kz+$`b8p0L@0UG!Y+Uqm;3ro#dv`+-itdBMOJIZIpSY>3)oGq^W9$a> zkD=BjbOQiv_E`Gn{{#7fqp0^%x&+0*8z|@ARpgH)%Mmv@i`9ff5WP$YJ6ti|OE;*L zl(wtl_64TQP2cm$SBX%QHq*OEK{g5<)2r~MK6K3R!M%GdsN5*ddU`IRge|sir@OBt z0-?3?>)sHUHSr$O)z*fmRw|;?fOE|;5fPoI-H^fm3pvFBOsb)|HdiwzQ(=$;a6Hh| z1p4F)yKc@5?URsdzIE%?!-o%npT}r{;1-WYM+4sSCRGAjyy)Fv@$2OqD{Fl~qwkf4 z8ybCwO#R(Z#>U0Pp~jA$JK+@z2oC>arbtQG`v`W>t)U3l1DZ%Y1Nf zFiryCv-1r~2>w_A6nD70(b&k}O|eFS^O#S$O_^_tPtZz#|Iov%IEQdIMl+xpHu9u5YZeQvP( z@#*?WD`!ya$b;LY{F z`wX?;oK$tzGRZH^{>l)3m*y8f+T=X-?a{}VGcrR&cZIOurS$f>!@_xM&R9Y~_+omL zn})G~$oHeWnpRY9|F$)E|LS7-ExK=quzwzfSnP&Ypp@7689`=Kajkua>v^Bs_s}zL zXBwSoJbY4(_V7EIZ<7U|hou$F1-k#dpy*rFcOeTVbQU0Q20b=pb(iy*bx?cu^}Qah zQy_)l<}{G?XrW<&K^17A%;AvHQ+7#7Nr<%Z$EHeODS!M}KBou#BN+fB@eE*Aj7|;6 z8_5FZ+t(E`kN{%n3jP6_G}e*Y7cZv7#Pqy-2Miy?Bp}L`@ZiCR@$1k%M3!EJV;3*X zY2!2PVpKc_oGRR^2BF(|5%X}U1gx&M#|bu`%(?lPdnx4d)vMAW=G(Sy1Gf+yfXf$e zz;ONqhyiHxIqr@+ENopE&bFlavpB=aqE{)V$hLqz6tiHak~<35(TKnvrp{aTI$( zEm3r@du7NLaY5YerA5)tRsZm&sQdRv%QGn}%`1;343F~O4`*d(}X z0TqDr0R9u))bB-xKzKAXY!!Hmw&^N_mTw;Gdo{5zhaM@Kt%N2uCPdg9I^5JCM&X{~YZ-cQ51RN_A%8+US zr@n@Yj6vM>GZ>YS3Y-+FnUEZ837!3g@7KGIUY!fKyZEY^7DG1-_OmB-!u2havY-k z>Gap`TJ>umj(|sGMI+Mf?Z>3=6Pb$iw0O2va&%Ru^Ydu0s+=U_dNbhkN1NnI;XY_;wZ>byTq)TY^sGXT zH$xlpyFTD8KvvLe`;8&z)TxtZBP`sLgf&y*_GA{Wv0TB_{ zBORp<3ke~qy?8uE1tFJ89!nYz5uQ4$3T=*}a1SqzWD6$9bE6LU?os{V0d?GWKA2HTz?W z0%yP<3&RYb(U&{p)kKl!1#9B@VrYbsIh8ydiMF_lgN~x1At9h`VTC3n(8a4Yf6GFl z0#rgS_K2XM7Jn=O#EmKK{j*HwslR{_mx0g2eWeF+ar(Nt=nv@{7_7c#kF0!XuWPA1`IqZ{J#O9P)O?QaWLWR0CJ1~MhjVUoM+S4m1FR8u)ol5 zk5&`5U!s^xM_U{8`KcAEq(x%eFeh{a;{wO%G8*O6)Akgh!oVRQ9VJ-95YAE!jR?J* zS^Z~iF@|}BoPhvBL)btGyM6oi6bwT!=^IS^pi18O_{kH(P#iq=6AmAf#NBQTr6eS* zMF#w_#1U@dC4>R_UV2e~Bszq=1FlN_O*D-!Uk-!Pb>&A@OZ@ar^FGWhQAce;gBmTj zQ1v5Zd-q~KGKI;C$=v;W9ILp{it=(mM3%wZB99*t37f0u6;547AqOq3Vd zd%=&sRG#!M4j611EH`i~U^^))`tZR6aQ!4)I=vub{&%gn4!r&1k<*;&38l_M)k#T7 zKx`c0=g)6r*QSnv-b$IFDRTXvKV5?%fgMV4%|N2EIqMRzVEq>80fNqKWXluo>5SNm zJV%d?F02wd6t1qIn;=4;hVrzNQyx??`W-hu3bfG9=!Vi z1WOSb^-WFxVht7Z9)vtV?#0?QWI6ud^D2N9^Qg{Hjg6o?5ym)ynWo?8`fZ}N<4bR) z)@Mp)l>4eZzdt-?gc#8Gs3+ls1JX7ozsWT8TfmIC)cOD;-&x#T6xkfgKTsZVMjyk; zipvvG{|ihX&h5P1Tu>YYj#`o;fDot;-ai=DbCfuM5~K(n!z7oK^>sz{MUSd5Q*ZC7 z9Lv3}x%FP+3>u&&Ba2)Bcm=wW2+g;fz7Xnn&f{dnHxdb2#ECA>4i3D>j=}WM@%=rA zM9g83W=5>%w zVcZQe^G*_yVuMEK`y9TwH3(Z3U}6D*)ZAPw_n3gYz)hR^#m9<%H0X!ldkZS``5=P6 zZ_iD9e;A(2$Pi#T7#!anU=kpYZW|{|hX2BF?jxjJb#+(Q*BwOYj@z_b(taBoTNu3D zxq)Y$lb3glm-l>551u&VLJ8>&QZlju%!}c}hdu@=@yuD*vml6YolAKRcbqp?dnhUA z-1lz5e92Fs12}QvZ@ImILR!;8K)BKP!NiC|v{drG8Jf@clLyO?8Y(C%X6t0*e%xU= zsm($KL2RrW+?;of*}h0_Yt^{ofi1$+8r0FZ)+Igmu%`iYWVqmxjMFEPCHEwjTev+q zh)>|#2?*R9+K*G+p_e~dqeXx`vZV&k!B;B9m><{|kly%Xk=i3(25W+jjrfhB(*oyZ zB!5UrzJLFo?ENh|Dyk2fwA`$$fh@PSYHdOcw&~*+7GA*E{!$Jpbk#vWhm8#3zF%^K zU?}>AH5SxH{K4k5kjc;xNfdW`oSMq4M#+%y)_URg0j{KdIanKn-qCZPqDPPF?aQ#r zS%Segee|LyH7wSgv$J2gFo)sx1_ z?WLO(2gj?^C%7*_2&|2w2H8@72GNzqk*(M4X%Y33c;n=zSN*GV!(S(U`8?+j1w=sV z_!c1h`uC>$3n5eN*V-74=N*2xM~~>Y+`HD67A~|Y#l(8e1PX;e11hGbb|>`8^V&b6 z8dSllsUzUdi#hha`}7IUKCqyx_-J~W5sJIa5bJ(pHIwW_-WV*;3M-Oq(XNuik99Uw2=5$|qv|0;L7F-mYNp`y8C9sS=4 z5qBke%+0N;@_XDVw=&PBHN&Ne@R*1E1FAO}WGFM&ufOQpCC3|hii)0|U`eDJ%b)S| zDg2xeOOc;@*4msCB)^PQKF_G)*Xkcs(md3dg1`-_uVh%NGo@r;mVk=_oSAz6eh_-O z?<=8ei~P3lEF(o|k3Hf;e)8|LRsArJB5#CD{CB^fNOiHC9g}#E9-Tsx3uhvK4UG<* zGU%?Lbour+p8dX(#N0vh8_4s4JSvf;q!)F>z}>MD&R0b_PFP=XI0AZ)nK`wxQao3hv@ndE-r#jv?D=Yg#cEVc>?WJJ{l~S-VfQE zk92qh$yHTU{C+m<8jyZBdN;=YnZbbbI&v z_W&!`kuTjMe0c#5Dsj<(T(!4n+j@KH-==)gyQmVmubE_FE+zZ=aluFD1jxI^Mmy1n z#k34Hngwm9Ud|QG<{OQp??h&e&2Dih7wH=#a+mZgMOo1jIB$95!BkkCBtddKF~uVj zb_w}eSs0l+PCb0F5jE#3jzK(#A4ZSEA3p=wuB87hg)ZBDu=s)GJ(nPIyWI|$YQ$_% z3o%pp-e?Z+dsx#KJU1~Bk(JfeqJz1ppZeBnL=&A*DiswK&B#L*z_#_lQOp|fS!vv@ zr`-evenQp)179RL_#cuqVQ{rOEDYOO7+~30S^4M-?)@FgO8;?QK;ZZe{nJZX8}Xnq zI*q&lQ|yXFtnHkG+?{^bzq`AC2EFM8X zZsNJ3eSt3FtlY*N3TZ%LQzX498VjKt4gF|pYCN9>IDmoESYBGnaMC?JE9-V`aqJr_ z%vd1nf$k4v3E^TEO2LS88G1PAN|lS~cK-Bv4$x~-5(5k}0e(PZ7A7ZHSjeM-?;%^_fv2Fd$V49g0B0Dz!zlT2)Yay#VuhVy^OE3b6o4e18A?nMD&hu;CqvBJiRO^mzp*ayjjo`jrcUpkO8U!? z?=ZCIhYtmlXL*JtajTPB_SoN{&FR4}pC}szI7T>xjT z3&C`~()`cvAU)}n9}`|JDLeB9o1aQK1{3Iqgx6=qGH7zeNrLsckzbHG<>}zevh@QZ zBIDkFlq=qM-9diS$(OFk*?Il@QN3rFIqepaqjRr5$LsaUp1&pDBVBGL%e9lsS>2K( zjI>mYEOIx@=BO34sCA!Rv{dza?r7Tbu;k*f%9rPh?i)YUKUNpVj=kzF8s}=&wwl(8 zQrD{0OGL_!ip*BhY)gu zS>^ltmQMff@1Q8*!teBSbrF2d$daLl23azAa{?D!|U_IpslZm1|KyR7js2vDJTQYvtFPi9lon554~E%e~_m)e4rBw-39GWNC%M_ z^z_82aiY&DX+A8H!qZ6Pl&n{S-e!|af4#3w>&TB`I*dE^S2(+ zbt7tdUZU$o5LIX#Rgzuus|}{Ay$InWS(#8g64=o=R~(QBq+d^0cM+X`a&mHL=$Zu!BDSNc;}>#jqS9@P*# zkBErSg#fnwmvfbrKo3@v%aRRkd=csK$WFPhj6W%EOa(EZ$f!}ZPZ@sij!yRA1P za$R`zfQ>-Cq?;m#tTnQ~Vm?(;Eoy~_mm&dG+sgT}nbL4v+Zb*4HxBQaKy~#Y^;Nan_>m^L39Udr*TW4tZnanT1uT z?4f1)3G{il!QxA8?{D{AqOg1^%l40NdEICq;hFJ@?%#3yF+n#b-J6mjK(zM7x$e@g zT?fIj$DZPE__44GhR4N%&eyM@`!VZPfEga3eu48H5xro#e0I54hS>O^~JeE&J*JNus}SJ zuPV>^k0V*Z91J42SXlRzheR?wLD5f0)8vj(!EQvCLakc_ycTHElg0yZ%hv|mhFz3Q zgrv7x9k2$zHocrhwz9r%xt^8EJ=gu5d4tq9xC~uoI~@pMGJzl~9I*2{8VKhB(DwxD6{K zxT*O0e!g>980`q!kgBEF>s9|09b9?07M>vX9bRLRgajZa#^l!&c_TJ7UayOjR5GxZ zgQK4!t-X+En2~MIcArY&{)8`w`qpLs+#5=W0YX2eIu~MB?za`^zOiCEltO=#gKX>P zFHRD$_GWvGul?~*kUOY-Z0lbuomdDfG%%A#LEi7bMlm*{{9`^t5zoRVHkn@gPU6~_Fll#()wH}x8&n5c@ zoac4)ZYy{G9PCrIqf=MW>AIyQ)?l#pQCkb=W}>irH4WX7qV#Wjk2Vy_le+2qKD}8I zU#T3P-{HY2BhBCyceCzA?;CZ!;0H=wV|i}j_w)EZko`G!_?4!-p{N=-ExlVM|@uvBIPr3I@e1VZPROyU;0qAe7wA!U-;srcG|7#MWxCNuEnoBHok$* zZw&&@4Eb^U4|7SEzST;2S`%|#hH!0_tXrp8Z8(?Sg#?=+^3_8b=_ALlx{E(w*5P9~ zZy33Z$tmiduz_+TM8#fQNo4h}DYO~Wjlp}t#Jt)3M zAv)F2?B;r`oM%C~(&$8FU_=nbHfm9e+wgRAk0DTAEb zVpH+Ed!<_U+~q7&Qj5{G7k1NQ=h{QgdM2vO)fjr@Hp=3e%}3hpzc5(yo2Z3pM@88R zWOHz|JASsh^rhm&!*?V>zg%_75-M+~2_ChZx}le?IChjy_1W&*35&Ts9u4vr5)@Ah z9;53O4{6>%9ak4amr+znqx>vP<*j0eh1Jdo*5EVy4cRiLi1nU`(Dgc6f3*KM4(C9*M5Cn`)T9r8U_)r30*m6auj6Z*V7DtnK&36;fH?(F%tPVFmZSE2nHv8M~FDq`cb;{!$lLTTSun85A492 z?{@bXzSp|&wS}^pk;eIN|4Pb%QKo|2ERo(kBQ1xEn&v|E1_Ao3P4q!0`8U^>t4#D{ z?ms?j^yP|1h$IPzdd5wIFGCBn|BtWtj^}#+|NlEpB4m?<%#c|&rIeMGJx7?$=wX%2LKF zPL9G+W0&UVqaFTAT4kgKl`iF0Eu=l^O_S4uyNAzFW?MEK_)>TG&wAtXmGf(>8mnGL z(mdDLnq%lD=aQ}!E>AXp*X^qGcp_0gaXzb4K;=aHvx5cJM!IdnN0zQ%R{t5RzFEa^ z`<=+8l;FxNE2+kN@%POWJ4lSe$DW&`j<cqo-Fq{9E{gq{f!chBMBF}|3(?^l-@pD2A3LmGuE?`QUNDVM z5|{W>{i3YneX@O43O!7m+u@RfC5+L-euHY?2=al|{e+`zS6^D|`&gRE>hKHOFGc8y z3QCnp_c=K06e{SRinP%VJMGS8w*7fTCMKUxP@Hdbs^k1+&u$0vYbs%d6PJtqejQ=d z{Sk9yU#i3x%AI0$KQ$6xn-qSgK5XTC4s-N$*bI25yxN=j^O)@KRNK&an#tU8dP5&d z$2l^kgWliBF8Z)kXe&deG$-dwt|t)X4?hdzo@;xp6SQtQsdSijX;b@`W%(`joP zd;YZ9?N2@4JV06OA4Tc?t%QhX-};sFIy#0?#htF9Jc<5p+%LU7pH*wf)qL#!lswc= z5yo+zfx_mq4m5@ki*gk{Z ztx=yA^1=o;ZKR^hvM;F9dVVwBjE%Ti_wdiKuW%DAQq9%_Y{(Bd@tH009Yoi_A=)=?tOZLsp2@ts%Cb0sMQ z*4xse-rcsRYIxX5c5h?T|Bj+A5Uhs@;nk}fdvFT90^mONo>hetIluefH$S&ayYRU- zNw4ZA&Zn-GM;V*=oEqr}nhbWGFIy<=EOI86UihiIQ{yWi@P#Hx$3BKeOY@0YEbA8o zn^Ex}z3P6+o(JUQm}^Tr;_52EJHp%2wY3<@R}t$dPPfzk zBmMZL=Ti3kQ-jf0$I@z=bsiN;^(8ZX&T=DH-s7Lme$QnrBP?Y1R>iD~Pl@}{t-JuA zsySPgr|JB%_3zIUje1-A$4;ellwZ%CeD-_p7xUy&i(KVMt##@poifjryQk+!XC#>u zH7)6HzP6Pgc4c$*bGxZg9OxH;ODw6L`4OT(qF)+Z0%Q?smK zX5OBW`m)C+GhoBEyz7;;|J{{{YErsl3%%ox-T#(NZu?DjClmy!JRff2_x0~bzc-PKhTZ-&v zPn?xtxWuz)#C!DG4+EOJ*ZYsG_vntSR=AtqjhslhX=CAYYpdXtT6DP5`|npSFX~_0 zERy(CZ#eTJVKtNgd4P9ez+l(!p@WxGrwanCpD^~jo^UCbGKk^W-l8(FKJNL)_}j|d zp+oy=H+}hr92#c!!MOruS{dn6g>s#x=*^?wYJr_$33h%4A8(^dU z?3#w5iOtP*wR;AgR#JB@_@8H~(!>(-DFiEjqt<^5vh1~ch_c=M_*uiPPLM3^j7e2p zOd1Oqn05T6nmVIgS1Yx+Bb{987stV2$3GE$LB=SlK=0IrSViY@Q-P9{o7n2gtkxNN z+`0lkO~ltWD6O+aEPZRWTNAFplA%l<)B1iY_@_`Z3{4j1%tl#7IjR*9*Sf z4+F_;feD9xkw%>Tn)v3riFaL5`YGm1Px4QnwH+^CsA9-yf7^aSh_?Ff*u=MqoA0HS ziCA}ccAPxb!FSq)i7LFnrY0I^$|IM|;kWiNDVml;4rCR}*}D%b2>6vr^V9csEOE5^ zbVPG-YB92A@tLj6vkrWTfABW8{7aT;bH@oG>jCd@i%hSvt$*7mQW6`B(+cNe?ioD8^_FoMN#WT=kCS%)Of%Lp>4l>wFh^r;udo zRpL>4a|(UUEAm2Bwd%t6e|=+gO*LW<JfFjT2l#*p0b6f zYl&>ESu84U{ef<_r>LW9_HzWWmRWHsaRxqoMA~KfwR(Sc%%;L=y+;K``ZtvIcw60w zIFI_ibkXo1T;8bG{c-IqqAE=L zfw6RsaEy4nnV{@%u zEt5lsOWI_MC9jLpYx;~Q5gzR%0VVve|{xc$H%B14$wcm@6)PPqTnqX zJUUWGvNW`h40|eff8zb_ zbS&PmdLhwVtSu#ZfIPr22)=sC$?4Q)i+8SfmV z82evJiSrxzx!uit#Xl>3_~td6D{sQR!cR>x$SaB}@oAo zn5+Wl#HB}$`j?apeI`rfk2{A=Ftk=iUW>7oT018DuxmWlusWJ+kjjXhjC}R`F}WDc zhj~ZrICb-EEeKf4yklv}x_zuHbRRNdtjE)@JIL&M zIbR<8m5W@Z=;z~-^N(^Z3hi_ul9A0T6H6SU6je6ziLQHQsPn{Azw-Q_0@E4mj%c3K zE9|8yZ-!k9_aCLY|Kz4DkKE;4y3cP^=dXK~6bk5EzAkvbZ9eX8n&(~X*q_Jh8_v=P zSQd0Jc{ye5Dc4LKP_r^PyjW^0;Ev_weBS-9~RhHIzZ?_fikBfQHu}$6Z&|o}h zU$3=7ddLLKdeHy2B>T#qkBRZIb{aF z2$U@;=^i(WDYST}XOg_A3TH5JtpLvI{0{SOq3vgs;k;9&iDa*DObi?#>J%!N&<~U3 zWa_Ivpk<(TRh3$iNU3F@+#<56=xr3k5BG%|o*xa{e3lDMK%8qC7irHpLs$=2{v2xV z^<3U>p)|90QCMWAkJZ7P-a1OY0G%B4*Oa2%5@%4O<0g8uIMw4g`2-crHw)68BM;LI z`#f`cInEWRbi81c^A0g)dBMogHhJfGm#O4)C>vI>#rs})zJ;Sf><0N>w)~vPndUb!5509b zYl2nSV`uO|^^E*S)2GNRcug5~wGl~i6=Zh^o znMf+g6;*5!CE|oa-kMu>E^)}o$?`dg4)QFo#~kMuQb-pGIdpVS-WhFIL-FLPYj$Lv zt_-JHIePZkWH2@CZFs+&c~~hYvNGRf@sMl6nuJbM>2M38sr_%1Me}aRP)2eN$T&%d zMw)2YNeb=kxNmE&C#B`S0YhQXQU5LF+P0R z5Tev9nx;Orx5;YEJ;9n#FA#LsR$hJ-(soirwq(eEMz2QGY^O zTJ&1*HC@mVxW%E$3lRWtMMAkyqU-7XXRsQ4KBJ;SFDeU8!O?cN*N>xxKRi5qo8|k& zL~{ENAsY@QNsnLpu;_t^d39p7^x=R18rTgZFL!XJeg~&d6 zMTh-6zlI%Z?#d0(50;2=^JstX?=~JD*jK)$!x_gKn%~yo$$N~6w7!NbOVeEFS*h|N z5f#$iohQ?*?B?RhQp0K&)8!hPY@}AblrM7=pVz)+r7HR9u?53&X!64}o^9KkKXN+v zq#ZpJ;Px?;P%Y3c8AMV0E@AQ4UD{q^b&0US49mCiF~tM`3@HAmW{Cf*G^ov}?mPrja&(L;{<2KgV~9l5F<52;Z`Z z>rF6scAt*?U68g;%<4f>?^T`0ZFGe}>+*HK>)8u;Zl=rKAmKjIUP2O4=)_kbG|yxq zM&dr+6qow3`QDot!I64_v0y2tm=EgY_vFL{Rf%}osF{t;1>Sdu&33Rv(Ui6w-Rf>x zUd>w$96YH_`;+~yPeXo=5L0Wy;v>`5jAwKWg8L&+9TAVSI)BSVwxN3>dL{M2_O?uM z>w7xu;Q2sF{}3wHKnv7sr_Kin#!#S0G(b5fa_!OFFEWQOG#^C2J^OQQ!l>S}c<}1z zH=bUW652iSMG>+VNuz_FQernq{M=Zws6N~En;WD~9{D65ewE=zkwl!LF`x64R#-P-(D$)0bxn663!FYMzep9s|a(||cA^HU9 zEX{_P{_x5jppx^=^JL+sLj7^}Sn7LI5zpIie};bHAy$Z4FqNRtjViIG=zm1&Y#7e% zXCNibVRv+oR=}A!8herXR6f>LRzXrtoq;0{2RE_;v#bsG-HTlaqE>EsV19?Mr+k4& zlVyD0{)YJ2dLEA9J6e5@siU>;1q-xMmN2n3N4$imFn5 z@ryw;aaUbT-jI?Y;xvWmLBc#cm=V~O<9)+p2Bq3P&5hjqB3 zjN>O;x*3)%_AnZB`Mgxt3sBaw%enOK*@DcusI$#&e5Xm9KmJIK>3L5xaesGbj%8sP z7tMow*Y}q$$Yk2Nac4c6j?N%%v^qTG;V;ON{NUGO#`Cl{kqHAIQxI$zS<@RxCC0oc z`X^6-g%|a`KWm8@gH2TO-=aQs#UBjpol@}0ja46+5jjfisDIdtvy|LLG zi5B!3PhMnmNziLq-W_Ov$#oeYimBCYQK{Cy#cCXTcVuK>hW*N8QTiCGFv&JQ=c37S zVzDc&jD%yk@UcnEiYhBt2M!EDD>ZGlv+lUF?1<&-+t_*qs8VCm^tA>vzJOB3wS77kT+!i8!f)-y zQ*=!;k+E+yp7b{B3chZYUMFpKk@#wVv44m&)8$Fp1(Cg%jx#~z$J4Rqq8yppR=~J1 z+AR1snxED>IRC=Mxt#cMijoUAw7Gt(D%?snPB686PvS;Ka`Bzs#?l|A%O;|)OZhm< z%iEsVW7n}DW<9Sf>ZmC3i4jrxN7pG{LDJ`p{YmEM%B)#pTrz{2Da=2X>M)=4zIBV@ zO;}RzTS4}QCMo_$7wRkb1U`Jj_gqj#Tl>8r-H@98*ZyxzUC|s5UIyN76PW$C=ed zlj-ixthpoS-mS8&NH=TKF`VE@ALyIyePQfuIQx12bNi#=>l!*T9F~h4e{u){jv2st}Quf)nztfKMcd8kN_=_-R)NY;p=IVp5-u_U#m+&2vaIsvRKf~fQJ0ua> zSOO+$!gpU|?j_ z>Lo5Gy{rRO#roEE>xllLT@L(77cQJmpb{6($?3lPjI_bkf(jD=C8W)24{LJM+}CbN zHyEdVzip97q99LayV+>5^K9MBXwMNj5)I`XItIQA%hUYCp0`GwfAvuo9OjdHee`RU z3i(`!>6<?gW+0^`esKyC1vEWk^vpQ64fSZvj=4i@7u{xz5kZ+fipI^VdRBRBp z)!*qKXBy&#PP80`0`O5|v}%U8FkWJt}M*=aY?X_z?qJ>#m-!Yof z2>^@OL#wK~x^2Rg$;?czXd$ekFw?N)B|TLdsrVg476Lktg+h>?lrSlYvnaP`iikBN z1F%LfDE81(4}1t{Bu0jEii&3U-&owg%ju{PDbKqP!>5|HW`yqdX*4BYz0aO^n#yxd zez2kSm(O9E*MGahMJ?Jvy6D)OwOHa zfQbX!F{RL<{nW$RXy1N=;NC*)LzW4;By&>k(%#OCU+`PQ4EiDYgs&V6(LBV}o zgK)I+J)I@e_>ECUS<q>X%c;JzYxe?H@42Z$P>LWYFaM`S zgiKs`8&r=K5X(czj~M`&)vtFTt{lK*D<+G6bf$F*KFHv{b>=Z|Bnwia`ir?e2A3~KpPY9PPLyU=eUuRW ze_8-_QN~MVc#qbSp@D%UQx833uq5c=7hd|{ecL-mr zPwkhkc8oWsns&l8$lbk+AOn6tfq~K|+{!m@%)B!V#CWg|Mv!tZC0`(t@LPQbnJW)Z zBs&jdn0=D2rlydfASMKfSTUgxl?6;sT4%GMs<-zVdBAD|Ldge3M&!H1KE^uy`?uMS z6_tEDQ5QVzwLjw-;vPp>F7<$A4|{RW0L4t=+XOCz+}4yj+=C!%0}$*qV;GD{EcWv$ zMRAAVs(X*e>iPKk4m)_NtEtwdm`wne%vJ+*Qv(fKni6$=z2S6*aG=52~$?KG!`KT>7(aWqzsLrn9gYa zRY0^$Y{DI_x}9}0RbC>WuKd2ZsJUG>s=8r;>6#o4cOU+ z%R1)9j%|KDaJ;drnRE&i+5i|2LT>n}*lfmPg@>54%UsMauLfDfX<{ldOCKa$4P0vcXen{9!a#jE>N9Ap}Z zn=sr^Jf?`-=$*RaB|Wm%;h74^1JqP6mj|vvY_|pmn~Z-x4u2>d&-?pt%Wd_@Ww)kX zWQ%Myd#tCS@ho7+AE%?1o@rFpty}+;_5+~`CB8JXS&hdH!fe(U_1^T#%F3rte_bA32$5o`yoT4|xp)6?@zkMpUmtFHC|lmc`B;m!qYtyOgl?HH?~3cG;x zN-3hbq2#)OL1Mc#e6FVU-IVSmx&07i&algQG3-hq8QdXkM7J}z0ib)xuGAx{JjLNT zUyzsAj}Z}|4k37lEfZlCu0iczi{l#)X&oMpiZ4FX2$~@HluQZ#tCj5Q(=#!-A<77H z0H0Dfh=$N-H!t|%tub)4cej4!UctK;cF4U%HaqDfnxXT;@bdJMMR|FppV2a@K(bX4 z|IrQN|Fo?YX*@)QY?1GP*=xEdb$nk`j&XHNW$QR3fyoFj08O&9PA{MU1ilPJ(wfwegfo!c!$5I&&r?j?p#FhrNFW+9PS-mAzO!=#dv?HPa{y+r zZ@?}WPvjIkFJ-Xt6Ut9S5zPc_GnT4A1ecJ zqu}&6vy5^MCW9gjoM9e=m!SDUc^JRrVdAPE;nl+StFQmLx3>)Fxlr{~4AbE=JJ{PB zJ~zK>3W`I6eu*Va?ge-OE>t+SRxqjh zc5EIbUN~fZoN~h~>lgqC#P=FZi&vu3X&M9xbJgwE8weP%6%w;TW`|i)ON@=ITlw#7 zj08-icw+6A-Q4g;uWW7W^UNoDP?P68YA(|^S>w3tM&(krN-`7ii#*^SxvhxV~vBQeafF^R1AJ1 zQS2?_E@6o^=Xov_zX9U%ILR(-w9?F;k|}y(T}nE=XIQRO401jCC!IT_Z`!?Pn9AB4 z(p1fuf>}lfX80RNv7>ONQr_0tNq(=vp~9E)g8P{9)X%D}wbsVs(|^1+UkhF2+#Flo z37<{zw`+U7XtC}NmhSYY!kUbD3Dd2-&+0p+Qd^bcMk}8>n;o=-nzpw#FnuGVqhk;U z5;^!5(gLk6@V27{@eE2{1fvfGx;NgS`57M0k0fz)!lZ!2n);*;cKg1DpPbT_fBzj?5;|}Rnt03VpLOElHHy1lFKjA=bcr*C7tGS_auDgVUBqBU zetwS1Fh8`D<1SX#)&!4cJcStCQsRt8IEX+7>|BOV|2;;p=r2ASalhd^&1g>BSUm6&dh>!DCO8=T- zyqd8`;Tk(CtRojI;&D=7#!#&?F{N{21DjdH=(8?cK5p*HI;nYB+GOG5z|?jMKz~er zTdO_kwXa?1kV3`B$jIn;UKXlaC$zP-VdLI?S5H4KMesRW;P>(Io2V9L z#|V?Urk}=y#6swKHrX_REw4BB)11*fCS`(5WDzg+{{02?LQ2UsLZl5Pe9g{G%_h=e z%3?K&&+yB|>KDlu7A*|oxUS#n|DRI&a5Ocf=dZ*x?%yZ%!KdZmK>P~9`d z9tf7f^)D|aj%s2cKkp-0MB5k;<)S2(zmI*VwKe(ue?BX(E)fqd;G-OhOi4lXKC`e! ztzdWk`l3%Q4Gqn077P_(+BZTpw0R6QUGQVFN&haSE20D~r&H!rlbSdR7c7X;(W|I8 z_{5*lq?QO*VW-6JB81xbCpk9ul>70&SNY#$vcDhm)~6y9!dyH;y4hKW;y#?pxN-g0 zF}RAc3#@p&7hDhmA0jNzto$o|-1UGUF5e}li>e3e(Iaya=h5Wg)H9JEC zkEyEHSDxkN@i0fop@IV4^<^2;zxSq+=P`xUu3bY;KDDv}0>waT8aAexDvq^0|IsB zDQ)sYf8Qd9p76~Q%#o5@>nkcVEDU^zJqIdf9i3bsjlX|Hu<@^(OU@Sv{@o;pFCM!Y z12417mm3yVz5X6It2qr7_?Jo^LfL1=;|OWcVhkZihEgl&+~02>3X((~^M7a45Feq5 z{)&H$E21*j=f7xDQ0XtF`|sb`_5aoQ62=bx_xk_&AD`jp<-icd&JnGYIuUkZECJ`^ ziA2?(OFk)TVjLF#`O^G5O(If=Nd{+x(($G~Qx)}Fg&EwL5y&6FwD~%+Ju~vsbWS^L zCOKG*AS}k~tF_aeIz|YX|24Yc?_bB)ict1Th2zNNSZ!@Bl9D{Q%ZW2xK2`2W6lp{H zjqVteQsVO3_q?P3N$Q74*Zo|yy1OF|vGM+HbFbpV#jdovpHb<*<)cY0hja<<*RoHz z-yWbGc#2Y@g(jKB1*7AQ2T=nx7v&H5p}MR@~<$ z@znaHz&TYpvhy+I%j6d+lY)r<(|V6+$3!X0Z%lrDCikcDc7k!dt(a@y__rr89;%2! z4g!LHUtizcV8t9}(aa6Xb8~a+>im)RAVeo^ppa)*!T~l7OS`!_91|e0Fh^XZ*ekCn zoOs>A;fHupetyW9IK(#jz|H{Oh!lWod2$JHpf7+B+B)s;-{(Jh0=viKHbB^WFN)G0vq|6!ZP*DRk%=Oa|(unPfvGl#XK3fWVOQ&8D4Ur%yDlO$@$VP{Nsm zNKM{79>so+l%@6O1skJFLpITEL9n7MH!*5BEBJhgfYR$+N zN~E8yOHnXG?t`ov1A}saSb#!e;e(Ge9&gss*7i}9pE4M5HUYyI+_a%L!D)hRFvf=R z^MmBmF`I~j45x`*9fDNsd(*SCAX;R`fWYHqYm4_m+2M;r1`K5VY~cU}X^R;M^Wwe@&=6Fj7Ci%2Jar2&th>EYzq0{O!8Ln( zoX140F z>EEm?1}uLsgrCCK1jk}$w3UpDNiyIOHsph_a$?@A+;L<5AkW}TUt!J~zGnhAc0Ga3IEz#pIpbcS1bIS>9i z-a1piu6fq9X(@63zq6>~)=Sb+J-HKJZgdyY<|T^5>tfoezlj^YL2Mzh?eQAn$~=$=~q7kII=TdYpP9 z<_P8o7?=9I?FkcVl}j-8HZYi)_ktH1(sa3XZ4{=ExS19olVBsHh(O{g`v#w9Yq1nI zIdScLHFb54T?yYDcx&L$+t4?Sy98%Uu51)9?+c?r!a3@M4FI);cW1-LFvkNs@)NO9 zR{iGBjf-cC!p}BKSAN|7sB}#I%AJjS!NOMtV15CDLaOO4(k{UN~yY-gro4x_(wuT|x z4n(P0yJ)jxuF@Cuo(RccA!@Pw!=_i+*7v^}P!BUe@hqlP2)(o@M$tH-zVJT|fkd z|Jv`B?jE2AMl-x0t_V)fdx?juL}>K9aQoM)%<$0u27D?^`b+$Yam$MJzb?wilbvbf z&lBQO+w`9Sso>u+BnMtFYR_;}4^+A4c22;~7r_ypU+^PfltP#z#N>NNn%v0naEc;( zM%N^`E|^%98_h-q2NPSUqQ#Zs6dqL~ntfp!``rXN?P&&St_YAvSqn#Sg#LkT+~ zMmU`%r}fyOFk4@oLVzIa_4^82B*uYp+krjClpIO+3lI?S2TSxZuEV7J5wx&VdZXM^4CdqCS2QspVkwVGOWOgw$RE{yP`2ekV*2{^iZt_1 zLH&9`Nr|xh{T54X2M*{NbOvn0Z0XN=w%3Af@AhU30saBY4zU{pi?HnQl1^1*|F*J$ z5pfq(=&9sE3cLxM3EYvbt&yGDM)TlCa}?@PFn!TEQL~)wEt%qgan5uNTSS+DpkbOB z_m;2qq(KQlBRHB-90E;=P)x;NyR37HIfCSkp6?o*12Ap417u45TD#-*(}RE&#L)fE zk+D*1h)P{rsMiA19^oT%dOGfIWG<}OtYeiERddjQrqAH^kxs?A)OA2293=>t_J}fu z1wO=JDmbE@NvBk(8{PI)45G>w9~pU7I1zsYk@@tz7ig8(ADcZR6B7+|bg*39B~46C zzeI7U;IVW(oy?Psau3)Zgb~0Fkf{&|5jYT2gn`t*40)Vt7_$sB4x3m*v^_Lr9Je*V zr$?w6qJ@b%KeGg47Nj92C-tg*^iVU&*C_zVf|xa?8zC+RdXw8ZL(IX+0QDZDr}RMv z*d$Gh)jOL?`mbZR8h1ev?ItcMLKC!aYK!LnyU(w}FiqZs#K$`ysijy7KnzMz-Bq29 zxRMxt?Xv1{q@04TsNyiwB|_ee@8y%!%Q11GrKF9JxgDoFCc&XZ*;*pK)BCo!OYexc z^!C?;<5b&`S~&Ic^D~R(x_y19&V`TB2d%+z!zSn|P*dI@t~#9IiDlxL5(&Brj)9cKx!wN`(zKS)DL*A*Woiy%*N7iu zi|)Aa!)cghiqgQ*EP7yP%&U)ShD{BjT2N#2ZNNbA(?4u z;b97xk3-)$)hASQ?@|_Z74%P8)5if%<>|C8ZV_ z_+5pi!&4shUm$1}0fN3+(CbqhScuH8^F6lH3-G_$wTK{A{n03iM z@af}6cBO9js;>$Tt!3B%k(o{<`uy)#{cabEP48_7l&-9zLMlZqd~83e3kMFMAdfs} z*g?HnG%07_Zd5C7AzuVl5h9Sp#Js*~957H%jm*VcRh5co^Ec)2x^AEU{4hl#Z)dfr6C8;byVGzqTq--GDp7u4uW&^mmDf09$osoKcnA- zd>G0tE2~7vuZp>;54wAJtZ(@%Xrje(`}GyXB-KuzfB$|qGPknkm%$Bq7)D{U%q)-{ z%*?n;ro?n(_}*4DX%-yXjV;uclk52LYy*OCWpJqn$1@&$7NV*BU}S&)6}Tc1sfeY6 zL(^9Dr%M$;9T+vw;Ko?|(W6I0G7wTAEUF%!eei%Z2(&H~fMwIr6ni)7l!fwhF>ORM z`mMNVr!KX^*`}kT+_3P8#(#&h|?k7EUPSU8o*0g?6kieRK7!37MX$To5t zh%OwjU-!d4jwGwy+LfIWo(UaNb3nlz*moAeGfs-zJyU%A{7_l|9D$8n#JX+uQ$^hp zmN>Ct593}ye4&P$%RY+pXOg5I7WKtP@^D{8-U>+NqHi1 zB5eQw+|Y0UeZI$Cw$hzX5bj2+%c9|i?~QJTR^DQ++*Y0(^+O0*B+X2YE!V@Ct zyL=t3N;rj4!YBZ{APQ7yHBL@We*UZv5*A_-JUqxFK_QV%lar9}M%fFWmCn@bHMab} z>~8`C2?WXc`T3zy>;qUTX+A%Bm-Oa#M7%5Q9bnA?>$$d{}q1e@EMt!n!;WSJ3Pb= zq|NhaU=uPwYHETn0^|<&!^8cNZKB5I+9}hV9g95(g|kj@2|hnRm)mnmRh95;VZZ-^ zRNd4RHyBY4bP?a!98l{=2k7_we1E0$aiu8F&A$K2vxHkZ617HZczEyg-!?YQf_$>F z+qVud|82E0HKP}V02cUjyK3lAi~9sjs3Jmah6;T(b#v8vO z5c@(OJdg%H1^qkdLCOpX4K9R`4!W{RCjq#wdvxU2X<=dW)5%p`u7Gn4Uto)LI?NmR zP>~&^EC3uvorw0{#uJ7|5SP!PP=kA@)Fz_W9?=hJ!;7*q=i{sBQ5}{iyLT?L3zh25 zx(<Vio7JwYo%a2q)^tg@uMnIe&W#^@(ECgZuX>g3l9>j_vk#|LmHjNUCC0ePAIn zLT5lLBu30(K_6rx_9BZR?yCOoW~a@dts!%+p~DP8RyH#ftZ$#)w=y6_Y?zQdb9q`e5i(JW&Wc&m zzRwHXuFU3LV`>4*!LEAN9V+YD*Z$_=wBg=FhNyYP$HfiIEGL?b&O`AtO4iTkXZ(dK zd<-H(R5I<$B(dqBjUq1o+4X$tkxM-gd?4>Bx&%s zqhMraF4xD2$q!vn&V=r){vf}2O2_l(Fq9n$*tyx60=J@@p%vTa&l6bg14GtdfXW0m z8GvkgA{D_PY%Ku}u!8=*=hnJ=Y5z;Kw>xJ6gpO}*6jVhc2`YyLREr3`9MXojjg8qk zIZe&Y0-KLychN@o&|Ie7p!w&r|Uo zqBr4E%m`JN$LYlf0f!-4=cMUc46Uv4x**btQMqGfp;|(XvwhrW>U5jk7{r zQ&UrO@7_J0kUvXHP~9P{8Zxp12OiC^_VQj^QQ)|SnjbZcJ}E3jl?^WH#&tm z=XbAXCz3xI&r5v}qT&^HR$NZ#%w`i}&ah~3Q^Q}qxfvt60N0h5d*Jc)5Zk-^H4;12 z36L?_CS4$ME~Bn$k5wKQdwG5(5LX$nOwtCNI@U1`*8w@6o&B+}hFvhV_9{X}U*8(- z6PMJ~>QapU>(#6g9&)rgz-ODK?V(bNIuUCRgX3AxscRsJ5|UWPFzkSMVf(>~5U5+t zsG*L@#Ww^`@0<8G#5S+cjAx;Mv+NA?W|b*>oXyxB;sOA=tn&Q?rIlBZiy-Mn_mxB8 zW75pz&GV=dqI%S+BVy`Lg;4WqLci3D!5OsfKrJOm;K;H4) zWes>^Yd*V*s;*aF= zpH8}Xg}VI&$9a$2zgjj#{Z0L_#v*3w$;o{$G>e=6v;FOyWJ=KH??o?*iuQ|_4&JS6 z&6Eoev7|4r^!ITbwN=g%=P~yC_UC(5QgYqsQ+3V214OiDA7Dr#EGIW|MNp9R&Y1Y6 z9>y?EE-o%HF};)V1fYDq|CaS2Cp$ZzNp%c~A-c%Wto2z68OOcM)>7(m!%>f}{4WR42@yr0(X_R(QbQGo z$nr!iL`D78?$T16^ARQWjh+Tfbrp;1_u+vSB@gXuDYnG2&ynRvz~~3Z?b>7 zK-QE~DT>$WB=x~~exo0z9VNfST3-BNYCo&_@ztw;QT&Uz;gYfj(#BI1J^@UHMTcY^e~LQBvaJ*j{axLyrDnzV65(WIs z%l#qt`r#`L(`q&;(;y1_=?b%a5xp|~^#|I4X|L2ts2W`6h1wI^p zy}kVEgU!;9&u2zI>6{SQXJGS;R!806P`2gxvw*=o&xnDs`DZ&DZx1dnm%GIgWvW_n>T;I`@*6cDjsMW4Jsqmf)F@{=ubg;up*t%&~ zUmv2YIXt@90>Wxj1)K2R zdBmnwkDh|UR@yJUaYHv-3(W(~f?i3ca6oa9=^@?+Nc`B&YV)l68Kk*3oy>tB-o8DL zy$COarrIakTUwmwY>+ZwGfTDTuZ$oigB$ATT{_z&W(EfISHB83ruzKM>pvzL+JeZ$ z!XmSE9Q`ufPS_|k`Mv%dv4_5r-dkvVQt#ZkJH7TP&yB&WB%BicGMR|^0L;A2hW=c{IAA}6RrE19hxp+$2ggDL@ zSZ@77u@X$;0*O=7(wo4I2SAAu5gNMTLzjy)vZlr;PurZniN*(J4ht(QX+lGcgtAK>Z8qK2XaZnIP-%>(S_8%zPZ^*y96$mQT3)~2KHXlEw@;W9E3&^^JM zL-TOl#TzCd@`@}J`h}ku#tnDr4?*GM>$?&EskEj?pEUG0)erU-3iTr`t1%W(}7tZb`+(fx25% zda81MX9}WjZm%^q2+eYf{n~TRdT@dPPDHeQ$$D><#d2D$sh_ZU>l;6LxAo`!DG_<@ zcgbyP4@uVO`w|;hzBogTGe3_^WCJ^_>Gi#^DlNyl~-VZEe+Q z7y-!09A$l5@(cR4FJ5C|IOY%?((vjM?M>Tp*t>5ZZp82}CGzmGA!bwzaq1j{_!q@x zXpN+i77baXDggBewjIgd-NEO((Yx*MCngvP@7l#D>nUZc&;z-n@3&wshfi0$yjC5S zCz25_LzyKBgJ42M7&#fXBsBI>g?9KNd7B062laK(Kc_L5f#XZz;zfL1bbyu+p5VAg z28bBPsvm1|FmdoC$ac_uLp}rfBTRmuKaccf?--7ROYxU0f`tAgI=iT>g5Boh*E8CxdPi^`AMj=LBHFCXO-Xjq8+a0z<_d;rH_x~ z#}~Upbs)Va$nCKaB`4QBzFjIVPcLsALA0A>k#mpRCUa(Kv)8$#!kREHu10DxRo1|7 z@_|CHs0BaRzkfd#A8tb@c=ynE4-E}T zh~!iqs#+7mx3w%YghkNsSzgaQMJHEr1+ZLk^r8uJps?=6JIRU=7NSJea-?KXk7bvZ z5_$kg(acsP`E5_e1E+oi9`RCAc=jPn&BA|Tyd#y^_J7-%$sxloQOrT!oc=jh!ENp4 z2^hX%ZjW=tFgWoz;Mv&T)atTDK83#=A0KJ6^KBMF^?iIC0u#B@r*ky82{_ocHr18A z-Z+%2;7f-6r0R%so^~}hVbm%h0%B^069{TEu%OB5o3=}$J$M&dNNirEBJc(mPJB>v zjZ5hdZaM)!Jvr$lnF7(961)>au!Vl?wO=B5Wur>e-!FlO3uY?tDLi&eZo+j6lyfv? zps`Qi*;+1v&;oCWI3>D;JyH=lX}!;q{mC#D6juz(-Y>9$<%^RB2;k~?I6%Q%IzFgW-YaTG1p?6uRaB_1^?(71oQbUHbV<-SCHxigh>$oY(98eixbci+0wlt#MZ zJGrLQFY@-6DblJ6*UmWO?XQL8VjE)7Rt809BF?a^c2^RM?MrAFeavm&!CI@98+|pw zdxyET#MbMPaCi-cypoE;eU_ClN#clKT)`9zy3q|(i!pau1^L?4C%x5V8UL7hPNq;F zKkPqmfm6@oYtVrH35KG$o6h~wY;Q%wI{MU(z7*X4b9~)#EDh+sU^HXw-P#(q z2AC&{QF**Ty6>;ph;O&&DO*HLNVoBt`*F5DpRl#g?pza1Vu-cEGC~n!SYXZ1AD+Z6 zF2Z8_!+7t{nxc9^u;78DBZ_?`TQj$3W~i#ip~6+Y{IER64y60^5k#t;8ANMwLZ!;s z9}YX0sqr?IZ`srt!>sXz3G`a`>9=fE#ssslPC(EQqQTE2MU5+bEH&c&o!hrlAujH0 zM~S^=5W#H~RH3sw#vejh?!$))z%JMm$r=mfd)G(i+#oH zVY@XnH;y-MAgGB|WJi9D>6xLCMg&KinuirMXU!Y~Kfu}Z_SMfphmN3oDt1PkhX;>P zi>>Hvf9XY7^kdqGiQA?&ExxBIb*ifASzW?0MZA? zHLiIp`nMir5+}v<_1}7sOSKjmRELG1IV#8hGf3#G<3z-L>GzKWNa*S-b!^JJP9`-9 zerNENhbeH*Lkgl{3VpY#&pR8V|4u}O@x~bG-P9ab=btFNA6LFSpl&Rbt)QgaG%RxG z!yE-#^|^=>QMg214c!FvJaF#d;_&hPTIN;3ay|XmoC5_}IOjZm{1^pe6dzIh_o#HK32C_;*@-nmlvs=4UY>bZ%i_4ac8o#8@GtD zQ2imo8DY!E_InEVTSjIEVr3cIjY10p4<{##V(!Ij-##HyxX5-93Z6ZC{uek13#JWJ z>}q(H9Q*B!v)^*QqxfYkql#5i=jpPF*#xPW!~yuyA>0#Sz+`Cr{{)9vyyRhlC#|9Og#Re1>2uw})U{rK5vL zFO>H+#N5j$l9H2?Up#*v0&-MuP#cF9u;u8@TQ$?ZH;w^+ujuf94Jo*tv9xh&kD-r2O{^|sPdlWwB!D18vWQMqby zA7IKUasd!xRTUP_z)=lfoe*H&sI$}i>)nc@?8y`E+bp<^7(0S(4H8D1EO!@|Kge?q zoTVUb&@a>_ZNTPri-E_4GkOXU9&F_^xS`*ORfcXc_0gj?8nN#K8q|$4bEBsIQ_b9T zPcnw6R6<4N z(cy!@0_*8D(-6hCj+aIUzmmCnL5X@-ojYm6W^4VX3Sx|)tJDgF>j<9=)LG2WpWl0_ z2d!^Vr&aJ$QerrAq{eAf-p6Mh8*!ugJaGDH+l#pnBorE!Df7|xefq>K<#8SD8*J>N z^6%`+2sWqq4cP05iHR%vsZ!Kd5c;`rJVdD!Qw0s!L|~v~%odr@{^RDxGW_C1Wn_l? z`YLcm5x`)z$o%v%sa4Ld!0wNO(iWq8ZFJiII zSzm-irFn2*fG}W>{TaSdoeOi}Ia6ty-?`Y8@Cci2{nSIt7$tVG!OoO_S+nB)A7$?y zPj&yt4{IaIUKuGQ$%vw?5<+BDNV1ZZmAyh`WQUB*N=S;VA|V};>?FxvNkm3wxSyZ8 zzQ6mv|G6KJ`#i44)pb?pIOlxcpZELqdal>>oB*f%04sgruCL7$2?R3yKYu z0#qbSy~G1DsXh=`m@$CqZgbzY?A#aQn`|()0zMPV#BXa`ZI4`Zb{0=-=GwLEs-4|D z!8<%?8!I^q9f%W<5i}s?RQ`~hbz4@!ZxeR~?-zr{%)5CYJ-CP!N!J)n};dX zep+(4S5fO6a(0w?C@Cne9ZP|9U~%#BIh@@BNsU_%vlCG4^MvkI(AKeWU6BjS;Qq;l%AadP7j@-O;3zig= z2GAjt4xUBGxZTUyr9U{zGPIY9*CRN+CPZb^nhTyD-+enI#M;ubq@HmT1$$RqhNqxz zIL-2GaMZ$?a36iqdu#{ib$mSp9-pb{cz?S>%hfW>v{QKSY@1HzLl;l?x3QcDDSj_m zci#16T$kXCU`1PWbKdcea2*f$a}C@W-k?(>YC}fJ#zLtzjFNH(-8Pbc)Rz z6NSTfU=MTT`VPJEeu=XOYsZ|y`M?G!FR=PrtwJZ~`nP))Xy^cY^m@-Wf*=9s>=aAc zgV`u({+#?Z0oxKd2zW^lC^#AOR)1Z_SLfgWJQkCWx+0x=PD1un`nVfY;#OF5Aia5E z(|L$1XVOr5bT|c7|+MoDE~t zN_cTFi>pV5g*_L{(aFI$UQ~NnetLR(aO5D`9?Wc@32?QntbMz8zggX~i@REV`KY9% z<;yPntbNefMx4WmrzsGdx@j4c4?l#bIhGm##^+Iiz(f4-;cHh*tys|?$)7wqG&qQj zin#3;7Z<^`V1KNxuFgx@H|P<3P5=rO1G@3oC35Q17s5;DU;TC*nj>8>l0CTpM*4SS z6ZN3tT#~tXX^Hbef6dzaL(R2HWnmH5*EQaEe1G@V%%k)0X`iJ>*KRJ)d=7v`}15QXNAMOb>!9qD=-^ z|C*V>Xd9ppu)unG`Hl@U;-tYoEEGSEMjzu0OTCVgMAFmKx88b?p1xO5(B>UzF?+PT zs41bu107Gy<+IqQ9mP|MQ8B<}P>cb30m1%C7jiVP0<3Lz7;|%N-QvQFyu4giul!9cXDLUv@AL>*CSU8zS~c{}8U1pltP#BvufZ52 zn?c8XIq&qYS%-~uUfwCoYmBPmp^<^!-sRsqR<@<&D(a=J905k#I^qax-~`4!16qWH z`yqc)W2SaCb|yQS4o+ql*zkbw!c%bQUW3zlOdWuM0H8mbMss^O;8K#4yE{8CKt!LC zf+j?!OdaPM{;&xXG7SvaxwxiMZPOn={tgFk>Oqz0KW%Ai)vJ9Yj^(hVcKfnZZ;Kcr zB$V>jDwjO++dl0CJBWML&S?K*xPb)70+~Aluzlq;Cs> z56mb!IU|k-0E+^)yPK4hgxy6a2ftop2e@us3 zZm5uR+JTi|$m=@@;+B`AdG19=YvV@@zo$*igdm~OybL%N8rV?b3Vs@LjF@nnOon@1 zQAkHv*OTVgF8&JRm5jiC zJ$oxWP4{WXeviVg20s(a3nkyfnKw4*e4R8CckB(j;Sb0k|44MPbIzIpOxzo#+(iKiz8TVK}7&;-?2~+|H>MZMHE1e#6-`~t)XEGdzSLnm-CXI z^}O~dEPKBS#|N}ko`rE99v+hv{D03?rbEP(qDfmE78r|VaBVO;nvOFPP#-`XfHySq zkmE#R)I27x6B8_R&OHI)2eb$8O-O}DQqmnQ4irI)LtN{RHTjdzWEluSBwg9uZvRe1 z5JGV5!`K1ig)Y4nX~3YI!kK9s@YKYUfLp_X2z0)>+W$;eSN`Q~@u%rnSW580(F4ud z`>rHme>|45@y70b`!H(mxnR=0dq)o+c7e(Ifu=IR%U{KIJ;(f8fUAI1R*5dWr2{Dh zpa^V<#GqJgmP~!(AX z*??2mzIz8$-w(M-ejfst4`Ie3DXF>QSTuZS%qp@drg)G0g(ho$TL@P4Pfk+&8Yf0EZR>n@FG*Ek#y!#YRV zsf+a?32js?L!k?rv7Fe#gLNklUF_Du?gDh1DU?@S{9INWWL!As@7$r8yf0WE5grbi z7wR;-zKT`5KX?9pA!Pl#4hRbRs~sSFQ?Bqpvw?F%kO}Zyd`j>m-KU$2aV`En{H}C% z(#O$m9svC_G$b~vkH*d0TN-n{^yex~_Cl4=*%m{)xvHk7MBV&c5}KT@uB3}zg@r?V z)Vt~8uB4o?fVftZ|8?g%%kIa=15HXeVPO;s{N5XkgCP5y_SiF}kAJrrsF~z9F-mM8 z3*Dw0`-kbOefiH0!6foTJ=FZpdzhl;2smRs;Y zecIxg8^k~_LFaDiIj0Ez|NgdL4Q1eeBjkvu`)%2F^DhYS@3oAp{yyiw7beql_(!1s z_y0`n#GUiM$N$f#_IkVh_sn0Hv@d}E8`1q%#zofflG!kIij(Yl%ZdNlK^#oX#Pu z7az&<$sd%^48+B%ZU}C?c#2k@Jf+-vZ6xd9Y}YmiF>ku{iALVp9(b2ZW_vw1?a?9Y z?yodDwX(C>Z6BFl$=?ra)E3KE*jUGRkhWPbS2rqn;Gvb&<>dRXf;e0_UN{(4O_;l} z*6+y*P3e>MOR5rGY$ZNbp9&Ypi523$5;VprJ2JmKVddt&SE@%ZKhRcvsu8=unzH3e zj5pn77KnZtOP6{c>6&`<4tp)76nDfVP)<25MEP3Mlj|Mx+?YQt!vZveDN&6{o zr+23>-wfEiw#4z~cHQ|_>(bs)2kocl{4J_g22UzT0vz{LO12QllrP_MHx!rbT~J6N z{UdKWF40mUiNvzNWFh8{0>LhIj=sWcwyqyVzBdRT-l1M{5=>_o|Wg+5Wk}mC2Yf+q#qNTeEMaYBSVL*S?!%aGBWN?e%^BFWYeDlm1j3c)46Hop`%gA&UR|KM|djth&|h4(}jGx%)Mm&^Ta}T zn3`^nmiqj7#PsGBtrMl$C3gCPZz|KWrWXdqs4ljf%6N8uvUubz{wT)uLsHf@rEFkF zmv8!$Kd56F%cpDW6mYs*`R3y6(ySpBsUj_l;y0s1eujpSn%c#C+r(m{Zpiy|Y&{=(!)1mGpS2N9&CNY;)3V!Nk4eDxpWBS=jV@U%%?7fC zR86iM25K0;NyY0QLMva@>s=m#aRri@TG5WQCIzqbm@aVGEIb$M)}5Zq|3Z?T4{lYtAIMh0oaV zB11~?xj4?frDym3^Hk$L+HnU?=8^?>ot(SD9^~Bleh#C@qc$r=vnUqc-gQQDAmkUL zCw1U_d-Udr9g)7G>1l8F`abc!`_N#bpT%WyaPv!llg?I^c!Qf|-;>NwyHH8aoL0}Y zGIzNteo$XO$L(E6|M~amKSErN=Lw>M4-Kdy-AGRbB?Gzr4RTBOYMLJi zFi{rb9sVfX%N=`2o3CnVSiSIU*UE%a%+>1!-dA4V{bhOhN`_pH*m%42|JfPe=xziI z|Hky)kcl=$vyDru_;{*Xi`fHKrY+yUrS8zo*)2-7Wo#)a?ts1T#=+O(ir%;1H-|>r z2^!TDyjLu2Xgpi=-&-SVps>b%%OobswewT7?P^4;O5(-AM+r)|V$HwM7*!=Fx$pO& z+@I(5#Y$YuU@q9?OO@ycOGnX{R5I6S|4Imm7wV%VP6t7AQ|?q6p~-`X8Fkczw$pFl z;vU)BYWX}%Nk#tpw+K;+rNux7WC5C))SNT8qsLPnsGj@6@S?$Oy@Pul1_Or=#3^1r zo6G9;tD)$SdKvqtXCGevw`Rx^=!s1^lr`{ygIv`&sw=JS*?(qKIBx`R(bIm!FLs8n z#N)=3MNQiJhtXSSvl#i@&q@a7Zb$ibys-uRg1zum42JVSA2MEf)-)lKynxL{h zhHXpCe|a}plLW2trrgY`L{e65g7uiT>a>yE&{Qdprp8*@PtUo^7^-M#K>%$>$MqvkNl(=oP8ipMP^L>#h_}h_vh$NTM;UX z#S%ZcQ?BiePo7<7i4L_tzqpg6X&2%>KeKCleqkZ=zeEHG{Ue*Phds^DH#C)g=Mk+!dqYxYimB)=h*1!#Uih7!Y+&f1?YtVeV|YR ziFzihu)I8F<=qVvmb7MI^U!?DKX9na2U*D4?B%B&`h0UJ^TexY&1n z4J|-oCdAg)q`^Z%Vsb2H()Ey^AX9xyON;Ifmum(>@oi(nvrE#wf0*BAA*tRGxpX&o zPd3pa`%+;K>^YO{)ZL*?I^d4f`+{Zq%&W33j(99t*GaW*_?3bjBqZEVqhnY!pZdd~ zhyk|gt?=KU$)cwWFyNbl(Wz|=(lII$Aly4V+3nUq?G&|8%|gnb)rOd#1uuwWAV7y} zK#<8!Zu>wX%t(_o=xU++@oxbZ-s`jS6BrjReq)lv$q@t8I4>JUcjO%3O-@a1&SDN6 z7w`k8*PwZ^{yM@3S*QY0^Mkj%r}i1Bi?;R*?DcZ94PvPM-R8bSW7u8kyN+nT7f3de z^1RsgJy7hpoB9+Pcn1^&P zSY|9ZmnLaZ z+U6Pe-qhKA0)x~EM1406YzYK#GTTnC0WcLEFobx7w5nuO1a}QSHhMHrE2*?a5Mrn>>^;o9J&3`So@^MEupCIKy1Lt+b-@OY21^f<_Jjh_2e6q(M90D;{v#Sf?0F0E1 zco){z0?bxV5SWAA2B-iCKXi|nzic4n3+E4bZwQUR83C}zT@RQ6q{+fWXSyc8v=l;$ zbu=~MAngH#jhr$E$vID=L)OzvkP?6b(*e*FvS7=4a&>ax>to)vYrJ<8EVM!f{KCR7 zyM9sz5YB6{+~7`|3hyOHNKHF!bOd4%)j&{crWY9$+dqHCHv$_64I!;dNJ{b&j{CWd z%n#Q#?g~VTL=J_Amshi15;iIv2b`R_+SyNauEIct8yGp6QkJn00wyKZH{Kz@0@*ZQ zWVFK0bTp+21FJ)84#~=Lvasms=)g<+0-X^kF)V{{JleAU#{9%W^9*}T9va2|N)H;Kw9WuDYV z<>+(S=jehlGCBL)6P=vWR(6mwnAZ9lNS3_(O`RNgwGkEpS?DO7?t_k{z+J&Mvj*Z! zjjA8PJq_d^+G=j@4cyL9cvY^iPT>Z?pQ>gQfLhDU{1&TgII!mw?jhtN0hll^Jj%mk z29*JrdB7lu@?#|5Y+d}-QhZQSC{`8AdqVs%EJ-0Dl)1g$a5*B;63KLcK8!rac_Di*u{!WwJfQO~*Wh++Wh zgm(xc4Id@VeJ~g#rN8F&>#NAiLpU5b7!W>K0^qq9U|%4cDMx+Iesec`dO+!@ww^>H z=&17fXdZyu;YydStj7M@AfYQq0?2HSHebv&`&I*%(B@H0oTehd+ozfVg$;n5E1g)h zR3r%CG50r|#*`(nb_iLBIqj50z2&w+QA;v{#L$i<(R}aaeMa>*Y?=bGAM`SJ7TVAs`z;x9h?JX9MA`-1QKf-e$(_sbK!Oru2KpOwCK zP7aP|?bhJ$Fwi%n-I|E?2L|L%_$-l1In=v^g+<)$2xGDDEC|?r8`6<&diaMWAY7Xlu@rRz?=f1GbVXONf) z=OI$Pgc&H$BZx2ib?Lmgn8a(xe_ae?44&3Zm6%gPi z$BJVE5Hy>;aQiw`WB2bHb4rP1VJv|lRE{_^z|Ilzy*NxzWqNwP(#fe2v2k}V!Hrb6 zFYR1TJCPv(r3r3Pyo;jCwbal9mFA$7g)|MSGhtz2JJ~9K|IH6Haq^cUJs6>FlU7F~ zh9K9;X>Xb&)RQ=;l$FWjF&iKo(yNjIcY|A?guj2coJjAhZUzgDnrI zYY?gGiX-$|9+SREVCyD%P?5CHo+W};eDCn3H5(OMGj`bK>#;-OJ~|8F5tI?AUa4cp zO0*6dTt@u>NiGr6gcn3?q^7z$$fN5J#+Te?l*g7ng;IGkg_)7jN<0zO7aU;frbO^C zZf_vP0rEG7P^$q#w0pTd;4KcMlrzjMEbd7%*w6#sH8%@m%|hY|S!_mJjfFx^Hb=n0ZYN9U`}3j*5Q z%84Rb#?0rY4rRd2+93#NuB8tXW*t60B!qx^emz%9e}V?Li>k6RkbRJjcT+u1ZLvB=%Y!K6cFd&^?7ctzGt&KMj@JihAs(V5;NiXJKJE zPCmW#2Yd=#(RNB2f=s7z2B~Jikt^vYnV1RRJbpSJY%^J;1}Dnaf7aLbWX#S zL-(jTj)~)?5h-wHpG5iHQEus+_t={PO_I-y+8Xk+*Lp~&uBcRK@+047I}y4Edl950 z&C3-AWT9T~<&GVLp&#qteB#autd~E3{Gg-6rw-`m)QkH$J`UyS;6oRbc3oigL1!a( zk`byWM#k8?cdK_i-i)4#c7eRw7*Xb17vyu=Kgi^RYTiTsu&k`hG=H~$g|#O`#1Is( zyu7^KP7LLS8f|lMB(+-~_V(RGCaBYW@MWisovkfU=cn=ylvPyjX*~kW`sS^~&J#I) zety^mF?>9xu4HMVBzu}a1m8Du3sJ62E_{H;moi`vK_m(ZVHJSJ!GgsM8dNnYiHUnu zK02|<5-AE$N-(hwK=4)8(=<9O(`*T%1`61&E-#cn2L9xv=d7(lP1V(#6}4cwOwkRX znORuaNlE#a;MXu8vbN`tbO2(Fv&3#4r;A(GR~d=4`^r2s?7448XD`H!GzH}3A4-27 zzl%9JS~Vk;Rvdsm)I{pg?)%K8bD?2{LjVg2q zKX+zEo1XzfUDVEqB*h_Quv&|1I`7pit^tl%R;oaFfi!a3&uC~k)i9OvblNYep>E%R zg@o;lT~Q1B!o)k?nA=;nDY?{iBpP-6q`5tw;+26Ed>ns z%$H~2=*ZB{hABTLI=Y13g2mj}`1qbE_qeYvu)$xs;C7f^HjRso?JK^(_I5^g_AnC` zl8C{lvD4d}p}(xFQwdfWOr2}76h~kO!QPaoyr=HCm#3Sqn|*G1aS`hac0s3EZPciL z88~KI+8Z>KR8(lQvGGEy*R7k0cFv9e4o@Al#JFl`Z;5Sx^lgR_zWI=cJ}X^)$SX3# zCJ#DPkO?S2WrVxky=$p}zhx8JTjJB0Hn*BU=V{7M=!sBed1QwL@4aDbgM@6)Tg(>b^dQ=LDLd zH~NMoJ|-qUrcfe;5Yba`m!aHLeV~cGJvJ^50(=w@BcJ!}`O76EWWf!|=CDuT&O}d* zdl$WIzTPA4v3SD}U_?QrR>a9`XU7R%n2Ak&hq8v4sDSJTBJT=~q)Ie|Hvp;-(qh1o zKZPDO-2=wAj!6P6d(hLj$k#s2y4A;?FNy;E;Z z2my_e(R%EW;IJ^~ZTlw=VTUZ%N1}3#X{)99UyX!kgw)DD4dP(DPFqN3W*q`3N@w+=T5lh zR(RQ-I1e2O zPu9gR@9U;pe2?9+yvy_3Q;VyUwX569RkwP7T{mhj7n7S2Fyj)FL(NbOsABvW4ofm) z|Lcpgp5Fq}?i!5hijc{9{n}4v{N<3CnElC(yQxT9>$8VlI$g`4Ny4y}*n_<>P=8Ve ztj6eZ`&XOhcSuunMrO5J2YILn$JdvC9vpFyL_6$DEy!M5t|SLe``$QZ?X`OZ?JZp< zqae5BU7@Q#$E|acnS8S$E&qJ7`)>ni>;b($s z|19~xWk=@o+3sYEYF68pty`%jV{%ZLL!kXG8r9SP9)9a=FZnG;j_wIm*O2WOu!mhA zDp1!lSSg35jd}m?m+~{pyvV?48vhU~XjFMP18ym%qhfV-E+P&5`djndWCTk5zlYB~ zJFOO6ZPae9;^8J=XimRIkq&TJl*3sRBDZJDE%{=68!&h~I+BQ+a{bAUh3+3d-IXml zW^b}q*MEEVpKlv4PR$+vx4i+YpEqUMO0cYHX|y}540Lauto3_+WQQ2x_Z-dX*4RdM zX$MCUHo73%FQ2!OcrD4l%W~0M_TQUg3EV+@@D=OJE05Y~Oa${(U`h4T*<~vKFiY#0 z%YR>;2-`0$w!;7Zhxp}9qaUGK_}`;s==lF+0mN^0U;F>O3jc3?f)q9JKK}Q3!x7|E zhRAI)GOE+)=n$(Yd2LYrMH_${AkIE)3=EIjt?dice}%pkQNt_*!@Q{@|8>~N?Zn-Qxn=~QvHItaT}c#>XY>Q-baY;w21Jg$xqfIw=&o;GAAvGS z7Gcon3NQl0`4|Uv5TbGG%z!Rntj$+{3VJ)YZGV^DTF&23`sVRv7Uw6Q&b6C$l_-P` zec4fj{SYy-6*vS4CxQfh5r9j;6tjY3SM7L7;X{FL5B7*Z5XT}oXueP1 zz`(oz%?(uH1n4dQ;MZrjon#{7(LCT`etVW8ema1epJU~mL# z`2h!PCDa!EIr~CuZqne!la!6_)MW`gL2mM;xKbj>$mx*?m00!U*72PeW&8P>pV_^f zxxKtF?D~^~te*nwnv6_|Sw+<1Ln9aG_a{>hb8K~^Q?v9lc`LfckocIGaLa#BgSZOV z4#mxq8>vXnK^Xs>aC1XLvgqYNWu8|zj@T!}m{}$c(69ut{0frVvp?}+t1;tNlE{0- zk3|d3KK?X=bSx)~shQQalzCWUAH^}$lssOa*hQj=ce-8mH!<-w=TMCC~pai z2j>0+T5zW|XIXs(zHz!@ZTYHy3y>@L!d-I8$jMle5rd0ZM0~Z1ds<<8lJQ;hj^=^v z9XY`;JD_xg@u~Ip?c3b?lJ39yi8=?tn^Z>>R{`(vkPl)o7iZPAGSoJYA1e#QA_Bzq zxiM=fPaVpRAX3zwJ0i|=whAIt0A||3eW0*FHMkbVk>09z_Uz?ieT#;KfnFGv0hXX8 zLjHs%KNTeK4;{)Yhni5)$4IP1`-CC3~a%SJmU!ieBi?)X9|J9B2IRGzQhN?a55z2EE+Lp$&FNA`T^Zn1VuRoX5Mr0*uj=j!& z!slgouTyc2=H(zwf|{BoW3tL+MeEKTYs1fvCfu^1x^6xbobk2GrP&Le6u#u}6)T`2 zPcIar=U0v{(i6lx1Ga7~aUE0901x2pFSRhgA@jR9hRO^zX7Vu?Yv3h_9F%_ygET~g zkop0|gM3s%%+GBO{!0|$VD|8Ma7zd7Laak*#-yc_n@6SCBXGS!9s)8@+WQRy$6%Mx zHZ-h&?#|cOH;>pyu|m2)Pdvp#A`K8FqU3*aP9XhnxXByaZGU&9^zSQlLwyS_hux0S${(;0NYl$a zubS(1=NwQ@Vz!v;#21u_Hk}2$2M&zC!{PA#=TDB#t$AZ+!TYm5i3{n`Fbn)T@3HNf zbz({y$JpfIs6E_1mrfi=50Le2(rYu-V4scAoi5e0DIw!NXK0^iU9@h^daK3dl%E05 z?4=v|IgtjUmvW^?%vTo{>mX^-}i=_JDT2gUzKnC(!Z;8T|=v( z#`B_A-fVD&M3Qjj$*JWzhlNGu>1c+|Uq-|LCP+L>)0&bbrhvs0iap^X8hNZlAXdxB zzy93Q6HbHTe~}m)iO>tL2n;*yRzh5xu0SkOt*I#~O;6+Y0-BCRB8fQgUAta+ZNFVJ zNH}j$7Hz)y;41zFhh{X-{uty`py6McszjZb&fs(uGl3EPYD~|KX&kUmeOZdOv(fp0 zn;xzK0S3yX#uwO!larI-X2Tb{+aLH2TR8AcC}=H4?F(Q$qLoJ}gJ~lO2tsrX;TCT9 ztwarkik~DRh(F`!f9xy%;VeYG5b7gLW8E5(#Xg^vXKiEby>&R`kNwq>upHbM5^irh z)sm#_-y$Q%uq(b-xtfxD4lFJFSI9VjiovZ|+i?HL5uxk5EceOCtO9l4WT#*5d0752 zycEd8H9J&Vk0wF-$<^(j4PDM1>AZg4pjkdBrb5%UHK}fw-4RXFdfVj05Dj zoOW5?Mzei7US~y{%4wGLG)!l?t8nYhUgc)XEE{TE!bH zx-G>&io>6}?eyrmvuB+Uvk2M;c9IP{N8vm~LJZ8PFJG3`gF(cwFz{aJ+B&1r3Z_9L z0gLpABb@L@ECKohA=@VX60A2ix9t=ZN6}zwX+?GDWVMwd9sX9;FEozu^-|H^Hq#LH z^7%qPa^apyLbL2j7!NUB2{QK`pUKJkb#4+>lHaZqVz+s$3@bfs1w+g+ltu1KTed`$ zWJJckx>UU<=T-7(pWo+y^73UoPaQCp!cGjXdlNU+SqS=S6q!t{aBCo@yb{2>Ra+*+ zz&MnqhiZ|A8X!+gPrsLeGH5X+oCX~t5f%s9w9{zRj>z!|RuESo{r1hS>qVne45$Q4 z@$|bP=l728_UDjolh~*AqjXz(>{>-N!{dRmMh5xY=65=kq>_tNRcYkDy@aziG@Ndh zBHzKUVnA{){fp-w8ro zfp{amH&lJ|JCO#Q%P{g|hC|QCkA%=~lICwC=&<)go?i2%f?k=Et#Xc4N>A5Zp65L4 z2%u3-+U@0<;VQX59MS?J1`|--b2&X|>^qNcJaOVgq|E#m?BcZYUAV&f19~6|+ z{=+0SZk+D%w-3C@v+Z%2)E2j=e9aw_E6qeNWM^L+d}w_@{{5KQ3(oP**tV5{%zpo0 zTtaCtuKf0n?Ogr1siHpX(u!T>u11l{HC?9*OuCWeT`)4*qDemh*6W@w;Inp>Mt!><*#V_Z#?H3)aWxjgL+oz=BRrlbZ z@0u1TgW41id5WFnJ{T!@ENN}~KP8=Uid=igp`H8vgnAUz741LlJvPjWIE z;-2ds2B$JEXNV~et(N}SX4At>5&*Pevpaa~nCHOT00u=y1_p09_UMAm zP_RHCGc(ZNACaPP3PqbF00M;C zsyT7V~v z=zo2$q4!cp3l^p&H4fMY%o|o+F>S=8mm^gArL#RW* z=h)bk^Svd~txLacsr_4KvP1U8U){DVdZ-4H3eYsIW6sbq%DqEGQOwJ6ZoubH0?k0cX&rQ_ku?F7JEwn% zjhG*?yi)X$afd56SK2+1jiOOgO>A-MvD?NMhi+~mKbgip`NYcgX;M}QC7tW{I*!gB zBhH`HEi>OV?w;C4;dVXG{^vEB*T*@Y?6ON`2r(BjDn9LBGF;2MU20j`Hqyy9Q8lMu zb+h})pz-vRJvvgw(?ZX5REmCU_6`5&>*-vpI=kV7Y_HXGwy(+SU-~+;-DF(jmG%x9 z*tvMSKB_l+(D(Iw^XTC3pG|IVMfaZ0ot$l#r99&lKc17`gErs&TJaQV`=@r@pZ4}w zPlwPvJ9GBb=3jXp9G^hqoIN{q-hJc)`$Vl{bTslrWW!!!#omnu&*;u{LM!|vW?h*HK)#0a6R?N{+5 ze5Qq+kT*gE+St2WMOD?>#zq>sbudJA--v|H5fy&g>O+K;uV6i2($iz1qZ>TTpN%no z*aUC`5c4}h-zjPxNqPn?2c(X`8Vy&|#I4Z$86qp#-CZ)!1kx+i>PWsClWWO-^@@%3 zTrrd>oJ%g%BmpJVb&P$Qu#s)H26>DE&*sGptC;!o3%T&zEdR@_> zrbLe)okP6r8~5)2CSRPf{N!isVV}Zxr$wo>#P>4YfBDkQJFay$C}_2ut@)q?qtT0s zS53ZEX8Yy&O~QT|#Ey6uXFP3^^?1skRqdQ#H6QeQ&5Hh!WCSmY`gt>mSET%fpy0E zTG=W5(YhEWz(&OXcxVG0J5&v?6FF$aGy3ss}7xGXU&eFUif$ZYrSw+3<}HipKfcYJ2dOI zB~VuE_sgtXH1d7%fA-BT7i3E$cw9FY|>l0>Nbq3)CrL&P6X`(iT1$ILVP zeqa9d!Kg8mnu+03eq2Wv-P?fpi^;a09rhjHw_fcWd?b8$jmPxRe(M&I#uA!}{@P?s zmOxYaEsW~(o3eH+G@X0=sJcO+uE<^VhS4dVyq7PF%gfR^xcynGcfV5<-7gw_ia9&u zEm!Z}e6pRN9FJ}eq;ZUT{Hjj&T_hLl9I0jB#@fu<(<9NTuN}9O{ye&plVUgF@IldC zUa~T$xrXTs`+*O6;u1wte2T%=i5%n)O6koee78}Z%rQS;ULBgzkw&(bi0e1|<#HtL z=INPwVe}Wi^ZJb;>IsJirM8f{SBZp?2FxrhxTM`%tcwp3u`Wrs{=E8%A#Q{)z4)r` zsZ-mM-bL_0WboGg59`dIhuq}yKvD&n%6cXliy-tx7BuoP9!ULCXcJe7Mua6~=3-(o z2R~Vg1116l58Vx>2o&fEZa%@l%nT=M!H{n`LJ%d8)(-u&g$1%P2mG zL!QiGwjD{O!8Id}FbX4{WX@AYifw!CHMbzh+lbVj`Fyk%$WQ`Ne&xyt6zp3Km}4)A zilA=9=RMi{4dW@0+<@KCXDlNgx-;n&Qj?RJkx+a+svARW2559f-J693r?!c2Vg?ku9WKb_NS^wu-|ImR4 zQBFsC(m^O>evH;N@%dT*;zIDCFXwktvAkyVCDuzjav5*a3_3{KSFXA`JI_Iuw$gHP zC(2E)gQ{SrC(*1(BZ7s{QA+y;7$JcRpc={|hxjmLe_@a~zC;{Mi0jf^kC$HTzm7~c@0UQJ7 zUVe$pwEXN(N0fcnKCf@zWi#Glcwu~@qdX<3>ARcRg2EQckNfH}Yl5%eWvAM)z@x^x zkGgf5R$iJq^Sr~`RS){THHKL&s|_k2!#Svn!|E55e>4vn3sCg-o?Kwt&p9QRHz5{t z$+MtCL1s?pmc><*UeS!_ndd69-w{0LoV1dx>eU(H>W+d<6J7@XLO(HouH}4 z1iNM6Z3}rcTwKkDu+1s~{w3!A;An;b7m`S17?c`xl)V)siC~LER4+9J#f4HwIA4U* zn%76=*2nItzLtyTk=Z0;2j+QCR6~kX$JIafNJm#tsKy|a{tL=9)#wrm3JWnj>6}0W z)2rP7X#rweEfLkv^25g)lZW`q5~S>@%JEZ*>^E zt%#tqmV-dvpHx&B^2Kn-db4m3J0Ht4D1{J)nqBg~CO^ib%c{zO1SW0xi&;zXSH~|5 zJ97CS(}+#dU|`)Y&{)n#$9-2;Yks1f)Oh(Y3w!No!#${vw9(nG2YNuJs3&NMBs#yB z-a#0L^cr8o=g;eEDP_)M(dPuhLqZZ|yj=f{A(0tJBm(sj#=&o*RHiIE!}JS{+8mP4 zgmd#h7k#lkvs}@sN_Fy!#x_=7|LvPTb&HLBE7Kl6ul{0eny!lMfK)h{++nCgmM>NO z;D1(lUgB6q!*`*_;&NVj8h)-N57HF(7i^&)>BtSDk%uvtlR3+>xs~d6CTmj6w^;{n zuJnYobLJw$Sp#-@PiwBdE0jK0J9;I z5Ch5yz1p5V57+mlT$;R$z6cr{Qm2&E=;}mgA)2oAR;u|Yz*P~Yg`*nvnHyr)h&kmb z$(oao(zAPq9~SkOf1sU>2?@oWHj@qsJ$8uRgR~DwN=Nbziz7*$&c7#VFs~4?=RlTA zOJO;ITZM;@4-sjYY35I(-Zt^aa(=&8FMkMxpzel-y^y6~xaYUxP$<>$*nc*2^VY2e zyc5G81{BsYA;~}p0^q5!v1`J~s8>MHi|CTD?;-t-iu9a=PvxU&{vDpqkN8C|ux%rm z@!UU#G(41QSk5Vrn%ntF_k=?$`2#O={xrQOtWEkiI)GULXTivbs@?}EWQh4lYHH9z zpkebI3YUPZ144~RiHWd&!awt!HD8__nmR5-$q5L2sI3KrejWPh$#V)Q;31i1^I4Iw zv3g?Dm)+Yi`);1oqJ+zS@c7xKYI$$Zo5L}CdxB=DdX%ia{cqQoX0t|Zvi|8haz3$k z%XU+ef4+%)SI}J8&pBZ56(A&YJnzJ}isqU0jZ=AVzqx+r> zb_J{Z%=XjMvaxFQsAW}$$ka}YQoR+a7E$IMVzYjBuqT(SaVK~AykW<*{8{)a`R2av zZ+?v#!ziFIthDFA-C!C_Q(MB)kn`8n&}jC{yTV-lsu#?4F0-PRO{Noc%0ttD9g&X? z4HbV#=9ozndlr8Pk4nfxBrG8+jWuvT=%I-TGYhs?IVp0!vV44jDPpjdToq@-fw!DR zfwU2G;CXn7%1XI11BeALqCCXlHB5i2bj#3YrKQD?5nMiuLr0bBmx9q*ljGwT%UlnZ z4H_OlPL|{nYVvyT@69L31LP5o(>sF`9a6u4({^5srh1XzyifGsZRoS0_yjJRc;Y~@ zemRs)&C#eaaX<+I;5c{?VW_R*KMjQgWYUO{L-373t_HCzlAuW`L3 zGq_a1ick?bVhGU5YP$foRfHvwe(`erR$26jNcDzm2a}@NsZSXirXJtTV#5OI=w6kN z*H4GD5LAl9Q$<2Gt*B+niCAT97tzW z0nm}5nIkgFH{jscfD!$j{G`RUTOB=i=sRW*od@JIAS$Y46-jAQuj@3$H6z+P~ zPkFMc=Xx%X+T5;QV=p|KU_ZMq?C& zZ|=*Q)p9V8WFEcCCc*V{>9+tE_CMBdB>PtaKWw9JZ1lgQue94gj43g*Riv*S=u8Bw zSf+0a=9H~Yx>Z7Ej(JSqxn71uL>QJhaEBFb>}NvP0VouOa9-t_r@K26fxXV2MX~YM z@VGK<|l|H?Wi&8W|W&jE+ip z%nYv`XjD%S;pXKX!O&YQD75tD<-`{%yA9FKPUVA|-vx-cmt%E{+?rJeahGtbv^fwOuZ zCZ`mEG~%O_O};vS!4!=z)rz^e1G~@wl;2-^JWQ*-lJ%xQgG>LIj_s- z5)l)ZnDT68>C>T=A z+-tD{0PNJ}neA7f9ZbpfG8f=eprnx(*uPKqocoIpGlQpexzu!dm2T&>OJ#7EuC*z-=TwN9Tx1Y6Dp*)x=O*031njN6GxILJk9NCDZB{8sGeUL`cgJIh7C$ z>pug#Yg2s#Vna~|3HDl*~CyYG`^kbAU!89^N{a>`|tk!er&J6y}`c%wuWk@ zx6ncZZM})f1Jw*w6&0xEc1Yb#=!xps&&IYA&29JN`8BSEHkietc{;mI4BO!mfGZs0 zVGO1#@LgYn{2mi45DOX|U4S*4XUyv5yz16CIty>m7?|BO$!iZGiP=r_R4;!JjSA<} zAXF9D-HGzX))su?Tq5O%%;i_P^c%6UeyO;=V&}$%I23l*&+QJmy+Ep98!|rG=_|Cn z^BLFImaB9gFTVXh z0GNQ>!HA52=25i$h~I^fm#VN4Q2F)g8s<%Qp3qBKG$tN4E$0gK<(l}I^xkFchcU-a z?oSWguHS2=zdw+hm~OU>hVk5mh15^Ysk_S$EXlFT%%^v_-z~EE4{0l`#IJp$zqu^IWJXmUXIJPET&1)Li&>V=OpEg zwGVG@UT8QMF}*(UHNCu}c*(LkL#a%uX=yg*hT+dNA+af$Y$GY7qaG9YAH4MSZfoy4 zCG$#|*)w}3A<4?hxKl~?CWPm8>GM6F<2jD!uix+Y9MAoi;$GKvp6~NL zUgJiNbDzm``CA8kl!O=iRkaT>UE}Iq?ei>pU@{jyRF}@LV;yql%tLB=iDD_9pS#Ac zbi1j?;K?24uYBYKc92hAxV@UcXe1rlU-ytT@d4Uzj+T+m&mQb26qGQMqK&xp3~_KJ zuWb3qsZ;KlUI8B+9SR)NuMh#+k^3tXTKSssSEc*S~EbmDwnMj)&q~nX0KNCjkT8 zaO*{D<34i-YMdD%At6W-0qVG9;4zB@>1Ye#NiwV%%wjwoF$_hcnIJa(xmF6VCSX(b z`!SrewjbGcR{c|QOrGPAo~;elgvO&W7DIp2hnm{uvSWsi%|w)hwLeumckhT6RSj%G`J6+h>0#ApDi?;Q0#ZNkzho%#-m#TG&OrC$FT==} z^ZIEPNInTUK}uF+W`>a-*|2!I9aB0wXH$?`-D<|! zsy=?2(@u$#QE=!zJAdsy9bfkpFOIQruo7Xzb9=*PFbvQbk$DZn_iP!cNfX2Bu&`f0 ze%vg1w1xq`wXUc61*`9nsz@AqHB~g8o&O|p3r3ubk$g4dM4Rx6OF6%=oV(x$dTNl9ozB1RV zKj*K3GXIDbI!LD8xOo$UX30pHKI1yVwqu^2x8n<2M9ZG)$#r`~92$!Dx^7$GcLJiz zm&C{`Q++h6uFp=+{kevJ96wFbk;FZ>x*jA91_T6vQ3}ous=T9665XlXf+}OoJqD`% zv9Y%s2B2I{vutExYs<&Y4aZ6tVlYm$XrtkenInM*dpkB&zfWNb*8+3*acq`a6T%7I z*!cUs@&30Ys*R_~<@*SOHq>;}h2f#KCNQP@1ZTRNH%p3m9J#LLH2-!@Sl>YaFt5WeFujdSHNq= z^&jTB+eWW3sicHuR!2TgvZzO;#U3|vV$T~EX~!bM^t`eUmOFmFzO(6#U0pio&r7w; zfKy>}+yaCTa)fQ^ebV9*c~!#D7kt`~HBnZOoXMX8ih+%VZ|AUmA$XPf>@=RYaqwoN z%*PB9gSMwfj)5)=O1BvEXEbEN0KWqlu0`}jRu|jLfY{S_9elO!>|zGD{~;b8&`y?c z(t+AvplhS5s_NlUa-Q2H+q*?D3io{KXKg|PXZpbpa%V1FXu*q{2{Izb{HJJeZ+!j3 ze8gyxw{Htx533knkNui8>fB&%Ke+8ut2sHk=VVgRvnXzlbOU`C0#Qt1=hidy%vUD% zUSJ^G{-T?(!n~9eq*H5|>xnZ6v@h+P8u-p})46KYG&GnURzk*^b157Kd&_|9!u?C` zlsN%?gBJq`PE^{r4H%iL5#J8(RJeIFyVdgQDm{APPdF6RzJQ9zvTZY9tlUX&-fUqP zM79#^r}14>!lpalZL9#`1pibm%qyXVk6~r;lWeOMM$3=OaKxpxiKaRGc5tzClw7eG=G8;}`yMj7&_|nF-e_8f|ED zdRM&AOzv4{KvM~H9WDxn5dl90l`qiA2ET-7S0Y5!=po*)V@DsWiDHXjCI84789-<9 zx}bsCUMzFh=#!~{a*O(#5RckhyWoRhlDx3AbU0MeDSM0seD`gB8`$Pf)-!mmFnBrt zN1n;Q+a=KW3yViM0){UxNb3-nuTVpGu-IkV=A755K2c~rp!h}04tVcVH#;#I$;r#J z%GGx4+^0Iy4;YWyn~NI~$}&5xA>r~$omMpeK1VIKf97Rq0d)JKDZ{a|KqsEtzdonm z{scg_M7g}74JBfyb=#o5_tkf3%Z@Z2dAr_CuYT-Raxyj*l*9>;+DH)W=2kZ2X~4IE zpL;7?6o{54b9fCK@%`v2ns6stWoQ$x8u6~_1?fQxywy^^c#~Wt`jSi1@OJ=jid`Wq zA%T(|$}?zew#i~qnw+;m5E^X*Ux@4W95_Zp);oR2E7psKt} z3)#w?CUcX`mf|ywr=M^t6O>FSS+)q6eTJrl_|ix=oC=7(wfn*pz)F{4s*Gq5{PN0BKT;$D9_v2c`z-zhG>4A_{|M;sAiM0bpMp`pBonL32bBY)y@8gF(Zg zBS&6wf3(Six~uq9%>fj^*0&VW+sAu~kooQ*{GAe1V|u@S{n%$0?fe~NbLkd)XitDP zcKGmNgxVCYP#-zV7jZ2R9`gyR!q{OMu4X&pR(`oICfR z$vi}!137pv3>)D6cabsWm4Fp0a@pt8N}&yVh*S_WHi9)x|u#Z4$+FaQNq^(pfTw0fXtwx@)_tm z_=>+d&kae2qeoXD3;|&gHrNb>CqT3Z*y}z!QO#8Cs<$uNhjBDNiH=623cmrn1RDBi zKDB%KplgN31f6UTIvjMe7(0!jX91up8kg_*N6g^Bh%nBx7@3pk9)dA2cU_CG)2 zFeJ43Fc36YIlu(VziP&31yCLg(hig3%Fqp|u1>_{i;HWZ1)CKLRDeJ$bGD#c(%9Kq zVBWMJRV?lUF#y6K2MHz-B9vPI3TUl8)N`6jhH{o^1d;sKdnTwtgn49za$GzSE`6B`?SjRVtT>u-@{xWVT{F^qe* z%`YBFqO~yB!Gp+5uNQPf${vnf=v1=@9@E0D(M}aUcI-M9AdC+M1a3mF7lAn9q*L0N z{gc)U=t7+ISe!vdV2f`7h#oa9YXI2PaCb#O1lf7Hy=(=-Qk^BlU;@^5v|ib4KDf3J z+u7bxgS^xoy>e8j)gL|(P6H!;C6E^fWNE~}8`S~i*6=&{l+e@D2aH(!kASM0gWDZV z^ooF4Iur_0;`|E^&`Sq6^znaUE>k9&nr*1*JA)%5;eCxd5-0aSnq2Tph@s&&wlp`3 zGH*Ni&3JWX88;k@Sm3!5>L%TzPKc)?RQD8~aJII$<1&I59CZ!N1aW*a>+_~aSim#} zAsA;;Iqx_o2+*c@3?FqF?n8%$fR(>yVZp8tT-Og*l}{$eQ?;@=j~?At$?xU0iWQ6+ z9U0^tn{~fz?0MAT#+cnP+=Vyb@a!Tu{NNw(=@&Cp!>Q;x3Lp@z(aHiZGPxDXPU4sp zz-l$~Qq2f+Lt)_w0q3VhUqRg55n8;AAh4A-rIA0RrR z-+_53%6v#0<5vL9A%f!b3o0WF^#A;cXxq~rvE^1CoXSCU%SCTbr}No3(I{6NXTG&?#Kl7n1-hEa zyV+sEa<0W5DoZDuLgnJ|GGjM&fY2HrPj52sXcJ{tk&=>vWlX1am^=serKIF4A{4#{ z6sS?D9zWh3RmylP2wC`k5*1qMh!@C@{3;dXn}oI*uhB~#9)P}3hZ0Ja%Nh%^aNUC-!uKXnO=e}G(ps6+gj z9Wv&#ccX5_M0MTVdlS;<^E+?DoeLie{Jc4cHUvr*q7?+X0s2iy8Hn5zylOmP zDg?Uvb+qX!r5Ntc&diLE%Ov^u0zE$JJ7zu2Dk&+Ei*I5vH7>ZsVdAX`4y36PG{UWu z($}5*rM0V)|IWUSaI8Q49nNG#ujpCmmYlrWK{^#tbv?2~(TQtXh6<>Er@3Ry)fm45 z60WVUKPF)^-|B`!H{gdgYs?A=2nY+nQ1*vjDi(%4F+5vPYeCk=3rq;O#8N@w9cMuZ z&xNtbNvASyT$esO8ZO=Ak*imn7+C}N?A@z{+!KWMt}dp%djaWTxC=W0iR<6bAPvCP zE8^UzL)@qs_bK?MJbLipoV4B;az( z|D1){8|HgALpTKq8!&L7r6Edi(?7ruC20* zoegBK(9pm}g!UrN*=>FYa9EN{JuqX?%QUWf%}V#X;pUA_uU9Cm?;7#aQv)lN{_su9 zAmH@KH|?0Ah2WGv1(sun#(D*UV~3;0vB)Fie)2|X(ne^#L*^2!%(Pk$s@N^NI1!8o zi30qn9wTNA6^7hLjhE^4w|9pg91KQF!h7%SqcTPzjhzk*CAPeEKq}Ck>^@i$a>7q(PcCHU z=7xH}Ths+;9gv6CT?NF!I1`1pD4w%#ALwnk@`G)Aef=*W%nWS=1|3|foY9HyQN7uoHs6Qr>G;74PL3EXvPvlAb*Rz6`FEG zLzVzis0jO!&}Ogw^yw5R4r@C8=*DB+#c@BVio1BIL>u6w$1_LX3~9ehdPJ_fhRU-pS9&1%yy6} zl`Re>6O|7qX9CXUq@<55FW?V?JDZYXjp7v96Y^t}O*m;rYB8ijHH22t^76QsB+{^m zdQElpJEW=bX{vIpiv38E0EU1XDLMg1e%>fRVuD<<@_M9R%adHBvdDZecODzF0oAu% z&k6S#$M@Tu96%OabD1&j!&Su;MosPdhMQamch;h4EkagAAQ9WzfU3aHpkO%nmoeRe zWAo3Tvyf+iC@wIzIFnImqH}<}ygPY65kc}cHQ8!d{onliKkh53+q7{7u0ghie%$^K zw&a?k;n;)9nVA{*vLSoRdYa2{7n&Ey1wOPBbYv)JBILoC7QlZ3GY{14fdk%pg{E)UIU#pv9~3YkdYgV5{Q``TQeZ^|Wyn(o@6v{= zKOcdkoC`sL!qI@jvi}AKPymhnfzsstC^1a5el^52^w|OYTm*keTx`sAL8tg8`gyJe zth6_Y(I@PG6KNZ;8vQpGPskSm!CVpAS7BEI>2PP4I@YbJ~-kDCTBZPQb3kIsCvg~ z>FXn+dt4^EY%@TJHB_yMHQYt7JIl(D)(DMf)IOCd zDXo_9cY)aE+1zsne2o|gN#KnUtIkp|2#(_pk{A4xxl`nmO&(Cc@S{mOxi|-M`zPzr@-Q^ozMv7FfEC-7dhBj#N3&RaXqs)gvMqeqTjYBfKMY@>{jWBi1SU_SL` zVs}BDNw&a2IvEUIkO(mFK`6rxm}lF3M(rZbcqqa_b}T_@(*OFcm`ya5Kr2IW^7`C} z7UWaxL<@_G01Q2W6e_=$(omFRXDtSu>Zo&38TIz|!sHdIdN?rd{ruMrr|`5uKU^`s zDMR4G-DhvZ;9op+_wBmg@@p)~&xSN<-7BlvU|7iVt8hYp4=I_4|;r1Ez>^%eEy zppISYUV65$d~8lG$cH4HLO<;);r!>=Nr9!Ph7{oGOP7~DO)G$?hw!NT%3qgdvHLiw z{z~0$g75_yujoaygOryAGF|#I|LV_UQ_iqx{4bwa3{W6bky0`9T^7~6=Vr{Y%#~!Uo zen{n9yy)+*So(Scg`9W)?SHR-Pt1oWVA(vW z6>O7ToA*e_m~Z-~`K6_08Z#mmNzdYv61}i~%nQNf;ML9N=i#wPIk(nlq^g*bqR4w; z$e#}%M}RLBgL5E}0I5WA1{xbkEkl5iHP9?2`M*Dg!A#x@oFK?R(MEG0wK6!TsMst> zfHcTZkKN8J$`9oQ1o;4mf8W$-Vo}JW&N%qPTRMs+uAbnsd8F*q3+^bRkpi^LBY#9! zb^=xRHB=#>5XW&}(P1YpO!-959=<^xvWK9X(&2Udl;1z)=~@Evf(?C9gHRJIU&i(g z=D^A6E5k0cb}Jcf;;z_PO}dtNmKjj>*`L@vmUo#KbE&*M5$0`yY8J!`-;&(F_D2iu z6^7@jj504%jy=8Fi4kH6`~s_A6H=iRDZ^2z`vsM${VM4n_;wSz_!P$L)_z-WiuQ;M(Tcgf zwBh!xhJgoxfdfrzmxx4-rPi&j#CgHpy_iD*In;*vyxiP3N#TzlLo?+=k}&Z)QXAJ^ zheHABhFtu02ot-Wb?ia=7lkSsAOBYCK6(;9bwBi5;HDOZxkXQ*13DE1)EuJO=!PaL zJr9nEIN8>^wn}_RceZT8cLDJaWy~K!{`P8J7`SuENkj8XfIJdvxZp8mra_wk3H;he zXhF*hMn4&w=HWu(k0{vKad98O3LFnAze*V@Ez{JsOZuK3_b5+L zR?UaB;bhj@dNN-3gnI2s|1bNE;{Wk~abNHC3}5QepB)-d2p3n$d`9)<-xF=r)M!%~ zwKqh}&(FBMI}pREU_AHpPtwka%A_-i{Qv&ueF(R&CT;3EFR3Fnf(?B4${=n|COj)7?`j#nV2RW6KjYVaAuY zv)`>Zsvo#iFV|@uLRs3wc5zwy$mbM>+VB6~NRI9m=80T0OFukfB7X4lz1_D z61(E)v~%q@Es6Ju%-i!l?`s@Ty(a!ez)q#*PV=ApV)X21-t^f0d{|Ig#ilOnOt+k# zb|6`{ww{0Uop92-J?Hs!r~E3v=Iponxlt?ZqQ=aO-yI{hXIItf6wKquOg7fEk+ws{ z=dFvZ(k=x1@V(BR->A1zkzHH7E;nA}W#B*alUqB^A%iTb43a=J1)a zE|E_sDw@|1zY35Nq;7e2ko#O6%c7%`*XtsYFXI<3eO6&`HI<9&o7=lBc_CiQuq1lo zF4-yS+C-b_^TeOkdM&aV8q{xO)Hq}d7a=V{Jo4`S_SIC0Fd9^<{=-|>G#yRnXY`9B zne=@-+swSnyH1{jha6Y-y>wYlA{onN=610#8a>xfQOGz%eeOs5wQ~xbmHWao2JGI zGo>6AKL6}zKSx_V+*Au>*p;a!r^ZjLkXd;4-dEEr{KL~K>`W)CnDJqop3#^vu&o>E8odAw)K7XrnN0!euPi4;&I%)?L#ocv){sQqV+oI_q-2T z{EJ&+B)>%0l#!bqx%z%6DXw8xZ$|boHT9F@yJ*#wKId-T*c#CLxRb*wd}6V+;(cL@YTep(7()jkk@;P7bk8gRSf-;s&%IcPy77OR1zp<&zk;S+w zZ0=WWOw2rz;d%Bc>+8L{YfVBp`>X1sEZLTok~#ch_f_iNY&u-0ZWAuXqrdm5LP7zIJxem(*Y7HB!eK6)f>+=FF8U zj-pO(4)%v~JfBluo_O>@`gW_G$4?=4K?N%Ioz@@!s=F!P>8li8T~hXZB5jiR%#^+> z>%v2|*8q6o0lHDHZAdNmFfN_(jp(b>p6ov6quYwJ%jQ3KZn?TEMdn^q_0o(P)rTzS zWcA1^Z^Qj!_FNTGkfHK2w6pBrZjqGKDdBl3tCHW$zr@~Mx9se4hLGKhw$?AK?#3Ga zz}Fu9PsL45(6p|JvCBJXRR5Cs^h)$y`qh-$w_OP*KP0^zJ}A`0?PESnODW^^DfN_R zZk{Bsg8Gd+aTedLpV%&s7ZonO+m;+0Wt{TN_%O5C;BQ)8#)-_eFF?_q6(`bih^Kd= zz~z3uUWJs@@a$_QCx>mPzK(AxbIMl@K2wF=Wjv5s`g*rxkC?ThMeOWfKg7V9r;qZ& zY1Fw>T0MVDw*QO}YBD`8AIw3gWob!0Ceh@iqgC9r&t+Nr68fM&H!|t{msiE@);`F^ zV5$mbj4#*x-1zk8w!5jg*-k7ENw70#b9MQUdd4DN6trfTu#ZKKcJFtU5FHoS8_%w; zwK3?})5!TS`=UzbUK0hXpy$?+;A9`ZuvAy1U=eJ=Pf(YCYoU+f$lRyD#ws zt(XPpFRitWMse@8ujbL1#i6Arz2(y~FWA4jdE1T;FLW~xygYS|Zfnub2WrtFOd|GW z3tN}`vcmhnU#C~z#jf$>s)f0eL$)+y%vE}sjg^%#|L%@m9kG92x3MUsa46I?Pi%RV zV3EYRDBhR4*=@J15`#!sL8}A1!aL_&(V--+f6wW*{P4oNImLJ>XN#sRE9VPE0K{F3+%?LmaWI@=knj5zbwG`l8&pP>E94LUsjff z!C#>EI;YsIW#aQk&G}#YxQzZ?pl@vzi)%V|6cq27|D|vL=YFcZyoZ89?Nw4Zrw@k@ z@3{la&4RI(+CK%C`_B&0cGJrQ&+N*PzH(PXnTmZ2yGJqK@dUv?ABk&mtAfbJWOP|) zC-yp-d5QKO;*02Kc*auTJ`~$BmR6ck$w`)|yk1zL)Ad_5gC{Svlh*&s7gkaG57XB# zzOrDmRR|7?u+;ROGt%MjP2!E%n?AmMV{ZLsn3Y9DTJd}ore+sZ{-U=n^7S+$?N|S; ze#beEuU#veGepY1|3O>y`B|+zj{91>QWYk>{)%K2iTquPU+%r7eXKUkLi@DNgkkF8 zP9xp(mIB=BXH|cf^R(Z(EwMG#M8T&rxQc`2#Hgl8rg>q#$zxvPbOEck@hi8xo_|`h0T%YwFz{7(o}s@Nl#Bo`iHspSYB|GqB%d`x4!-{pYDnDULN(NB4$yqS_3DZ7poSd+qSYt z3VZh?QL&$Jy%U?!*1ZFjkoW%jNW%*nO>OW?%3Tb5LE0}=VaRlhZfynprb#{&wHOu@ zdWXV>_Q^EwI7^sbuiXW+Mz)Kz{_ejYLgnyi)w5_5=7kd5vB99L>>ht-%HytwEzyfz z)iU8={OrMfhrOwDkLB1_mkrt~Nk%u)7QPLxe3+Bs=NQudD|632vB^};W`>ryUgc)D zj4JGxE~ELO7HW8b^PvEn_`u)@?OO*wLFJ*r?L1i_&yAHb46d%c;n<>Ju#4M=B(vbG zt*H{O71p!z!gcb-i$uwnNndkhbnfUW%h7^XktuK3#Kq4q{`vdY{s}L&m;A}) z*mN?>PMMH^JoE4xrDLtSGr~E4^z>VeL!yMedC!F>=uHw|fr5g4qc}Ee_>JaIpPf&O zrU>OMGzk9jC4F`lll*bWYJ)E+iyRAODetnx>JGB4e^B}whymo&{; z>WftLXR4mBn0Of6xPv2jr;O64`(2lQv^mjV)t-4om&o39G|&D`Fvs^)fAN4-jkS*w zPt@`B{TX}xvpxy!DUrGNj`ybe-`A`|d9#M;99t%@n;%q&ds!YXPUi4oh=DD2ti?Ol z`RctIZ&Xgl)i*ZpRyX14v3P9pZmoceuUwtpZ1*7R+4~UP2UQ%%X5+H{NV9Y(Q`Q;> z*CzzB(!H(;_Do3FqOm(^Zjw#oHQVOPa^4|P#yJeH?$A7`=k;>mEmmuBPoPftLC6!E zfFYK(l~k|f3I)UJLP`peJdE2Rn2?$F{X#)o^}-uAN%4uKc-gmpW6U{b9ZyfMo}3u^ zT1`3nB`;S}Lsz#L^`%vzGVxUz4g;xtHQ_(WRO4;2|KI-<(04arY@RSNX(f}#C!HlG z_d)F@Kj5HI2t=1i@2rK-1q-(2x`SF;M&_%_sx@I8)?T|U<&PY$oId=tE*G;~RWW%G zXk}%q9Nw6P2r5sz64N(4(D8$L#leGB=7Jwo+~M{HCcdJgBJfqTh40CK>&f6kdj0Y9 z0yd>U`f0u<>U2=6thKejI+aedu%E%465MU5OO5=62O1vbvp24sP6oy$_9P5)xarV! z3GvQ(`&Pd_E={H-2ddiNKytCLumDysDQR(e1tUiSg@Z9vQ;Nw(AYq)r9Ax$U(}(Uc zHT#JZCjd|YV*^%So-%kV(1bu6>#<>LQM3&?&pq?`r0>!nifcq5;8U#D z5{+-BJ4+fBgI@wr1fca@8TK+VLHP!|NDy!%c1+ah`SdY@6ZjRY!A-Z=&*#-iCa*$4 zH7q2A+kd|gH|R9ZrwEl3VC;Z{*)68kA_yNi3pjZ(<;y6y2*Qh09JE6?omE$#4Jhof zfnt~aiP;K=q$I^)4h3~lW^f0yyGDT@Q&%@Ih`|i)GH=i{FjE4@k8>zAr|$-M^R|>I zb6d5Vh7e3fyrn$?c~=#RJn%yEOF(ioYB==S<>%$K+vEjy3o2@8u)i-inmtZA*5n60 zKR}DluRD?6p6s>ll3w?~^a4=GXwhsQL(>z$Aj})RDV>UXZ1PS*si?#G7z0hA?$L;P z;SpNO_3HwwcR@{J;_Ly+zxI^9>H3u~MuAPx(160PekhX6BueiJseSyvPvd*8MgV(a z-n3@STz+0H%0dXedYB4;p?$U`Cmp6xXwbTv3)%Lbf`$QP(dv!wIzWR4sB6B<8T*1S zZ()bJFQ)iT;B13&BX6FINf#PU?X9f}!EpCk1NCOadsFekc%2DkX5SOE_uqL86GQeub&|d)}f^YQs#KsYAJb~5XN6!S=w_% zI|rRDh!0@G^f9X#D9+8S&gF*uRO~C`t1IKFz+7R{#i5{GYu9iwtbqI#QU<0WA0Ky? zEz|P`_|z~{x}c}WLPLpNTK1=k4?>OGnDh#I#vls<`;EgSs+X@`X=rK=H3jNVQ1caM z;9?Tm%8NC?jPaGsG@}nIsGNP77t^AZY{{kUn(dioWnO3s!uS36@Gx4E0InBff^}dA zeDE$~dXHy-LkqVqSSZ}UL`$}J>CcZqaEnd>4ZG+^ZfcqYEbwI1=d)}cUkv$+XPVF9 zy`Urng})Yz;GsGRpeym+o;}O=TwVl8X~|texrGo00xd7yV{ASjr|doog)`C@N3Be= z3lb9t2w6C2LSEE>k{?E`a7A(3d^dP4zI^*|R}=UZb3saP>E)}ChCr0!rh}S(kmDtw z&7f5PaI3YgEn9+wci)FuMz$d!6|#5-e*UBhw-P`rnxL#MI&#x}#WIDTAJn>lVC;7+ z+Ri5NH`Uf-7*ecrSdhLXPbf*3tpaH006jxY797D-!E;u+%9bN~Y}y5{5ma-Gyuh@D z%jeWnuZ5@RdnP5ho{EO83`V9Vzl zORi*+$KhF7kdy>S&7VZC#Y8XIbAe-dUc-S9-D^*8r=D;vA&99XBM8Tga~EzD^>}Ls z!-f;#flHBaF9Wo>-gJ62k9^ce6cb7~MsO=M;|;HBmBRc9*9a#AF$WZ6M3Z?51GyQ1 zx(ZjCXaPO~ix2B_9G_C<`SDi;1q--W!((FsKHi>XfdOJRXd+Pk?l$C~nV%G@9BBaL?9Nzgis3JNmfX8?$%Y-PCYZC)N$#9=Wpt<_qjEfQcq2I>CmVRm;k z9ug7BOGz1lDyr+iEYLML3_eAQ3E(j0;Te2?e;4$%@P=5t_>0N_eng;>c4vD)02PAi zEX-J;Rmt(8nQJQ{us0n+{$j?jaq;5c-)ALVfsiVlO|_L}i~$p6wPE#d15R^MW{i1P z2EY{p2mnfoSAfzPAcVrkAhtP*3!LIw_xbZmMAIlHmAO=!h2TB_=MBd98QdDZXx_?0xBUsyD;{aT;=uxsf z*9FiJd8v0EvwmL^XuGSO4W4gd$ztpNg9$TUjBK#)mX^vu0m~6&KclQ%zuj}fTXk-) z=A}z+CA|oyhQLC~+(Txu3ZpkG3yaC6W!S-EdW5GF>>c2p4l0BrUdQmWk5S(*UE0Kk+S}RCD2YfDLfQx17jKOA1gIVU~USu^u zd=SrQTRlqTIFf%Vw$F`|%gmeN5Wxt6Pvj7;7>C;c53Ch-I<^tge>n(`5Q4#FFboOg zB*GrlQDV6X=CLPS9SLqLbj%RWh+_rgdLL3ytrf5pkmYE`l%0s;3c){TT~|lRE2#(; z(D++zT8mN??C<&O@1`AiQVR@Q$m7Q^%2M9Ev4FV?4^O6cDl#uPMCk%@0c|sAkOQNS zVGox63mz5h+M?Lq+ptIvDlisQH{!Es>*!oc*Jr6s#Nvf$+e)7wo1CVoVz8jcqP@rW zhjHAPD+AdCCC4_&t=XNCF)`^D+Tfd*bU;3|Br%^u0nnv2=#;>XC2T_Iul|?>`Ciq| z0APBImoPzcEYcfa#CaR@_%Rfz?i=!VfSCjwHAJGsGf|m98PWWPh@wiO0P{8xkZp5v z%)`S%LuYY36RRj@LPJk43Qq?eZqj+?*9Gu^al!plmRc;r+d4XaPdY=+6WC>x)cm;9 zeTvuGkI-X8Mx4C3vUe&BF*e$RGc7U36E=kD?eHUg{Os9lE?1d*z%cQ7PCg(MC|fNZ zih978Q^CA9-!8UPyU~!J@COQhiO7r$x2gLEvT74?3uNE`vI8a*8Xahn*L78EY;3F) z==Oc2dk-GCfr|_j`)$>~x7X_IiRbo}ZrWfuo#%dO3z;RoB1fFgxCUf$On5k6Cu!cG zp}zhI45XEmpkj)=Q{H6KP_UhHD?8wMLds0+3{azUy|&@$=RxI)?X0I5cy!#mpX&fv z1tUvcIGoJCMAnuoGydY3p4WaY9Y{MhAoXp*rU>I z0AWF(Fe5+z4+@Ny21(@h#GxDi;34LQje@B~MQ)8NxF=m*$=v?+rqmnPZ#cbGkK@jL z?KkZlBr2@u;NW2S)7i zJ*B6E6$7zP>2=R}>%2`nwy@7|_E$={*a5JXB=q;_G;isHn)2P~FE{Nki3pqZY1ChbCr+FYdMJ@67zPHTUHIQ&aany=! z1+8_X?VZ+dp&SDSQY=ukFJA0J)(&$fi05M`>jT}0xhj%c9}@ltk#nrOKqTk$z3y&Z zS+&nGlH;?kNXqsvVZc5?D8|3-vRP_=S{m}~8OUY;umbZv8}Ax%6A%Cp=zf23NwpT% z0aqLshfr*bt(6Ci4oJf*SK=jv6r5hQUX!5$A5-TBuc?)n=r zktx`Z_ANkiA5wJ4nNT~?$@1{>y5mfhvpI28K^@x!UWs-6jF+1;G%UgsVxEVEy@iA^ zl<%Se^fi-Q>crx#gzx~ngha1tcJAA^NlA?N%K;?B1NROO_$}h|thiGwzb20po{WbM zHD9^~nzYBeZw?kAy)7+2CY@2A;p7GcbqIrhw!hytgRw;vn{{!_)8tY^$2*h<+kDHMhZn#ro&d&CfsuBDL?u_{iX( znB&N^q$HfZ!5jecLy8t>pOCr##|8%RqMxDuuX(0>`=1N=rQrwFTt?5l1~DH~~sD%((-Yd-AXU zm~N0vXKU3-1^jhxW(I2ux9{r55CnKE6Y0Phnb=0~bqLNb_oL=3dhhxLClCfs{vZ0=g4l9rw{SeiCWX>hhqfS5ls^t+_c4_i~8jLYFi0 z$rAyHi=pHj#sZH4qXupn2tQQfH9+W1K%5)67s4}_MD<&*mQ(533$`OXJT85{dkA|( zWTRkf5qML3dl-yBTO0Kx);&>oq2yABQ(Vg+s>AQdL+Foa3!>8?eW83N3Dv}p`6X&< zYLG$#S;bFSXB|*bM|z8n!>Fy78;WG$BSCP%A_o%_U~ia^Bb$Mq&d|+%=55tDH6hm` z1PTijTHz!LmIh(?O}r&=`)DkoUL5NwI(+C5dMIxKfp#oF9tB|l=vKZ21Q!@IsJ;d* zu{?*jM@}*oe5=s84QUm7rgvc%X0jK&4?K%8$kIO(*7B8UT z+uaHPn1^1cEu>!n)~EC~0vZz=rI{)vE)G33f~JUz5S~z*kkLzmXKk`W#hp!9WT)LOK^`7?q^d?Ml88}Atq7qPPN8!a*zW*l(7h23Uq&cxV3g_yh~QBm`|C$I^BVeX}L#3oT(Cmd>{6 zgeM&2(;jWH=-w@85H6T3CVO6poL#mPJ0 z!R_1+uj7A%=xs&b>+L;a<_#Hqek)GjT5|(wzrL}x@=D6jhm3P* zyL|+4Q1fe|)^WXJrVZXwe3SF5&SUX(&nI?X&EhpPH4U)4!=d0k_2L*u9kR@+a!@(Z zb)gk+0SSobssIBGK2m(KF^In9`m?7Ga&SDH$YzRtD2;Uu_fqa;_^ND(c`R!GdIxK; z=sOlv5O$(+T9xM!t<1{=9#l>3y_XyZVZ6m<&*l4%+ns)L`)Qsw*v4r0v(fUrtB_z0 zV~EKh-{!fkYZym{fmv5uYxB^kD54R8Sgm>r(iH9>Q91B#DX*^^o`rK~Vb9okZu?h> z3UU>{`NXYV9bUyBIHPxZLwfA7w#HFurWs}8+_c*TlZV30Sja0J;^SKt6i7vH?y`{o zdXg2&leKgqGAXHDB(t!sNltWCQMy&*>E>A*D>(t1GY3NsXANp9g!H+ieuAZdm4Ray8A4&MzmcNtV9w>aI$0VOyi3Mr+flwYe;Xu}>T~F{$JX=IFV1u(F*_ z+ag?8hBn69<(oNU=7SxlljRjQoIdQ4n3NREfoUVYVgCJ^EF}z35_TM58RJl3G)f5N zPzZWx@olcgM5}b@d60?adCP<{=0dY?Ub3IuUDwuv;@X>JJ}%aENN7#uyYW-{uDGC` z8Lzt{@#_f`bQ>ws>T7};9kx7_k0SoB`#XFg{&Ufj`ULnJ@t+j`|NV6Ap~E@~e(gTH z%Nngr$FdSE+U=)m2DYDK>|`yx*%)nPu_Q2{3Iy%FX&T+asiEjt^rt-zphm~89w1M7?bE|>&V=8PFpiW+d z)-ck5v)LT~{__2c*LO853^Pz*q#?SySzenEj2*#$HQbk5IyMlWUYofO7RR_#YUh!PHG z3(Z|~x;HnAdiYBJ82@|~$Bk3$Oz|M2N2zfg%Fhqzk2Fh41ye_G(@Jehm$j7ex6F2!C5Y@$Ti93;SHo z^D|(W5@W5Qz1GMr96P9|fr8TjZpP4hw+QbV^&Cbw9w#=;EKp-|FIKr9*H6y_yVWum z^pWv5DJhZSZf|*~i{ePN15)KjJIE@*nuKj=rR4n{M@ zaok?>PF_eJgzjD?f?xTw2fZ8#5%)M0vUrDEpoUOQDCtCsfe{$*)j%h+as6%R;-N%` zvJJj3W|{!XFP9`GAy0=uhLeK>6g`%Im)9C&@Pm^MS*K~X7&Z}&%TRite<16X-Y|gdAD~&7*pp=Pd&UsL-@X-QKZANoQPIEEa<>dxq+Wme z{IHZzCn60+Lkb)+Xq7@6Rb0HsG+S1#g>Ptj0U{7!#qM|NyHS9a5o!?);=CRm98^GY z04pCzxHv~3R{*RU-5^lw#l!$|?wyFCwYyS`nFc0lU%A}Iq0hK7XS+J5c$QGxNj6wk zToq^GK6-QvhB3T{55px#*~$v49J)E?K?xR%rw`5e9H!d}-x;momk=DQNnb-7b?=X3 z)ZPt#`7cbvjvADX;Fz)(j|cw?MczxMr{r-mS~f@+|9X}sSlpH}gdh&mG>pPp1qq*W z$kv@ge+$JPs>;f~%LO<3n_r!qHA%|NbwEajDj|#4?Nl`05L?d;$u#hEC!I>ctb&9E zk<4O|8vLy?AX?_&R6!yDS}v~uNy#A^@ZmzCAYMV`GrN zK&z`A=_+yjEOXyTB_z>EXg6rtQjq@1DY?gz!h3X+qpV3&keZUiXn5ZFVb3X4Y* zA{ZDAAw!y*oh3Tem&z73@LN;U1kVDrS<$CM4=@=94=}#K>&N3)lx3aFklM9*SwmLm>bSa11Q2V$sPEVrlp+ zv$#H>-uxaL^F#(rln#-Rk*I^rSL|;MA4DjPi!+`X*seP)@=4 z45t-=m*05*VtB&fku>x&MVVO(Yv9I-e$qmycke+T;NC%Sw#?1qW*}mX@}VSwROD0+ z4!ON|2$w2}7+bW|y+13csH}7kNbJ!fiEu*E_S&^n4h8IcbPVw8@C5X*WT^Tai_l?6 z=l+Q5^^K|`%x_WTcAW;!blAqH5eHj ztV0rjpis6vYKxf;bd%?Yp6X+Zq2!oxa12jS4i|)SS8ROz0da9XjaFzJqwO%<4rib5 z2oI=R(!4YwfCKYKyb8SV(2fLS!sIwg6WAMod?GE<@%5|g#J37uLCD*j4SbEYN=Q+l zevKBrzJIC}RvIJIF{{O%GMIIwuP&yywP}`QASS>g1B|6ks8CQD2}cAni|worTPU9u z!WTf%*A&AxNYCZN&`);vt{^R3{j!T&xdjJUG|Vy)I`Lj+&kHE^05VXB65gnJUE_trX$SyK zWPk^eo&>y3pm8;rMt#G6hMxgN$KqTJ9o>csq8MN4H&~G+Z=>`krX@p{%90w`LHEby zLB}1-7{N=1N+B4vv4=-VT|dN|Kh@N9-~dOX(6(sw?LR8cnd2`YIsLSb__7EQ!)`O< z&lg1t@e$2{Hz>5{fn+*j4@6?ot#?(-HK`3Y8%oLZRB4@}ek zJ?e1~^$)6iKtT*m)REU~Lu4FvO+M*d>~(DAq!8wUd@79 z3P%_6nN>PAG5!w-MM6hxSnBz{?psvvu6bL!W1Alu7jPsOXKi@#1jFsR{^kaI0#l$v zmOxmw^_-3X=k!S}!VCT?A|fIN;};a}=efUp{mMwS5yPohl8Q3pkJtJ|BJt*2XH2uT zZ1>v;N1Tp*Tq^!p;-;B3me$|}qU(kG0_g=~<2cWnfobekJm4`(N@O6E0GTLfVW@{mbHMJ+@vAVu;Z|I~3L2#}dj$g|P2wH?R80hzK5Q$5@n$V{=Jag|my)=XyF>I?BG^ncz+ zQW=Nz6ZMT$%6&`oSKNw_!K21yK1?{h^r+Gdws3Jv)mEC6l?X35EK-nfipR$sp76VU^Gy#j38rbsb- zY4BPPO^JR^CiZsuQ-B0K@Mmec#o#PMRso?K%B>z-CNSGT4;>E)IB^UQz{JaAWjdw- z`XLgIsh5arcIz@}wEs44$YMMhn2Q5TTuRSCbF^EWv3Ob+`uH)4`bp}3B$~uD5iNZ* zzHxM7$_ByuZGPao?KjXAQf}Fm25A`>nn4EpDiUBcPa&@f20nVz@Wep&Ss`)hc18vkI5$xQ>TEK@lb^z&F`{sP7V%GJ;TEp zOOgu-V^j!51=5{ugHHhZ)#xf>Cka7so`Hswiz^)ppFryzK5XUFwa(_M!Jl_~kQUBf zQ*HqSWqu1KW-<>+ahRo`d9|nxEjeT$pluykUH?oeq3P^eZ;zCX)a^~uS4?8DSEQG& zYDGhh38NaMpYWBz(aKn#Bo`JLS&XOyiKt14M1Uu61~7(;%fhdEB4%e{*2aEj9|?p( zZ{9kuIWsTJ1rgP6qJ$w(3Oi+h@IqrJDM>;$8Q3ixVSL3C?<7}#NDh(p5Lsdn$@x46 z&mnhbof|?IeqaC~6SVoz+LI@9XLqI~p^kzQS6*ofq2&xg5|H;+zgvRoO!!ETXIMcR zhIn8iE&-_~VAQs-gW&H47!i!zhWXAyk&fti;*O$&ib{duk=*hj|G3AO`%IB>J*>I$ zZ3l*&5fMy-QX-5i)Q{wu;3bGt^<>?4;mUZkas!9QaGJnSVzee`#6Zed^!^*DSEpaLM{ z(GR0M!tHrY-O&RmycPZmaDBV6nUNU(Eh<|oYHn6X6C4w83@p(=;@*dp5HSU`a)58{ zyPsBp;4#v@mF18NpFs8B{QJ(7D+)!ftsZTguHMmVU*1veABMx!xbKs?+WV3Dg0 zQgmx~_urCn4_mYiL6m&^Rx_-hi01JNh0%{6XJGDvMe8|mlKdRmAC3StS}~E8knv{y zW~zoJD2!fOTO_0pV9b6%nY>?hdvTjIuCWgZGhw)`+5x~)x~RI!6>{|!wjYJW6|A<} zEHvTcCI8yIfD}=7r!IK;eA^$IDtKd60_-=?-%tDLGsaGUp*a4m2_m?n5yF5|4jeXE z_z2aA+3Gpy8W5JB=$p$>p^4_p=gxlMKs*}giHTfjD+K8p7ZhBgO_Pp`Z9ihrHN^;skj*<5pk5r_@1a#kXWbPLP_E z;Zhcv-I10q0@=WL!ht*2LacFLxVQvg1xrX<+E`EU;6Ai`fq^KDys%%HDv!# zXT+w5P9iPkX|&pbheK7cdF9<2`qoV9;^|~afXb@feIyJUG)!)LHx)w@YSi&;cE&x- zf(a0Zx8ujGQ#hxI_n(6~5TPH0`8SCQA(B=p|DoIuLZINNE++SMvd5k)Z3qP50EToP z;h}Nd8`|jyw|`@Ln}kUY93MNml%4hqV|WCiPyP+fq87~o?#)iX1%<95fSD6EOIyiY+L1nXl| z2=S*n2c~-(jnk;TD-8K{CpUbFk2AtR$-c0Y%dYjh64A9{ba|W{&A(nNY;gWz!%K`f z5V4Hd4lvX=enhd1t3*s#aL(~Fz^FF5-n0l{8BAhfD7%6YsdVs5Oxkf=;IdrS(V-<) zcuepNiYsy>wezWHLWdx^L*h2(2gfBORu;oHep|G(wQavt&|6F?hSo6Dm$H~2NXU8H4%6gtXZn5)avrQw~DP@vtGf~<@Jn-&|_je;8U z9q9H0$H!ddM;vezS8P$LiMS_(wpvCBS!r_9O}Un|&&Wm818(S{XUegzTnO0?cE{ei z%;@$Lu{=}m_YVsxwl6gf7BrVXIWT}Mg&2|_sL2|!KXKspm-6NC1dlrslf;XYtl~0H z@k9OuCDHwNz7*dIA76DrIjU}f1T7)w8T$Lm7n|F7041_nBkmFUIdz{ zt7hX8@jvp%mR1|A$yG-`j>UOX+>^OT|JZmpH8Lw%p$D3ICOS8QtgCAkLPQ;R4md@= zzmFb5>U+zh!TW3dxzF)hS{wVyKi;@y=SHga8>MDS%a3N1)UQ^&U-DZQY`h_ZgPguu zG+V&QbZq2~sF`H(yjwaGiz3+IGC6VWpt=nx;EP32r+c5@mG{4g4l_ev7x zc1MTz68FiF5$8a{8cheu*bo@xYj^B_@2A%ZMH?vim(lHM7{IvtX@^9!I-TtQ#ooI| zW4*3%zt)@`%#JFRh*T=6kSU}nr2{FED2dLYLefDgO`VlYN-9YRnL?6~Q)eNSB$Y!) zMUoO7DE8-mthL^~$Jl#}ckFlXfA-k^SgTn)p6B2)9mKcXvO8wn=~$;rhp-Ok3u zIIJI}VgjRN8WrfYSNsO$p%~YbzqI`FQLb)kMyxt(;xNl%XAci6ncYh?x70LU7;gWb za{zdW4H;^j&}|wPF?91{oVaFJyDxth-Iie_Q|MFm!85@4N5}d@06w4Bsh$VONL1uUxt|gz|T3Q8A01GQ9d9X z%z1iY=~QJI`iaRBiL;!dBC;X*fdd~7-~QstpC!e=Z#N7pt2>qo1J5`7++%vW|4{RX1C*4kW38aId|UR``aG}i zKW+W9V5cK7yO>RHAsdrAA}xVU48etU9M~^GT*}heEfA_6Rbj6D_Me|~+0lKJnE<>F zSQ*ey%ot4k27pVeL05}vp27u45*-p=5nA289m0VJ!In5b#iVVw7dwf|)X=brJ>E?$ zuoqjkG&MC#9F`MUL8CZY8tfqp|7NAv8c4XxVcmgn?d`WP^Z zV@qq7FJGS6@MSBgC7;STgC zD5kl~e|1)Zk|E2Q)G;Dk#XOGFm6i4gH+t-wgs!zBWAi9kUwB`nHw0h8azMF1e?DpA z#L>EZ)Nfm}e3XH><@|2<9cA6V0ie9U7(; zw3HOuHl_%p^Bd&(CV(T^^;NvU6a>5zp=t)_pD#+lvF-u`bj`xR&vHKA%^@RB_BFvUVkvf2s^)rBmgfoHum$Am) zuN9uB<=-zrkd%_L*zr|RSzT*E&d&m59j)A^zxUmjdN-dj8}|0u@H?b-wK3C1I0qOO zpSPm-G>gEqygyFi_ff(2YU=7cShgGWL+`%>=+r!6=Ft2w7>J1)a7S}-7+V2qfVLB> zPdhm}3P%C`BidWTEc@q`m4+4fB_TgHcaTUfcHBsP)Y`fuebanKK<&95BkjZ79exiL zT~(~Q0ADZ8$p>}>m;+UINwLjUd`6V;z~A#U+SU5ZO#IU%QtnDd28M>V*6@8wwnR1F z{A*g6lX(BBF2BCuj-WCxFxbq)#a6-NCaQo5Px5O#l$ac*qqB*SaMm*T1l4K(69vph zYZSf*o;xq@0v(@vs)EfK=#B93!PLHXW;&XhZse4Xtlp_C&9j0c5SAmc)Vxq}eQ57# zpQ#dXFpb~xX$jg)3U{#>QQ1+!sUv?|R(`(7vGrqpD>1?yfBOsk$Uk?vPRNYu)8C|* zh~38R1}eT=kM>x*uk7df+l$rIOS^UN@yWR_bYg!;m3Ft_PkEJ z24B<)&RcPX`ftx4H?DC4A-g(oAYk7tV|nAt*laZA9@TgR{A#$P1)`XKbmHCo*Po$# z4*Y(QRxbB-QEgLy`D?g#lw}yR#l{zSHp$?qYm8;=@HxK@B2gpSM96t5OupyX^-ar&pWDTA+BGI+VX?8${xmWJ(Ce4POOpl5yv2UI^*3(KD zKWM|K?{k%&U0%0HZ=j*8uZqou$F)uG#sm#B2`sI}kHqa*W{$h%pgjqs zYjp=B-W{Cfht3AaifhQ3>XR|pdj)xQj@tG{UI!d`D-qKU(jLJPJ$&#$+$3M$E~#a| z;7U#wzCBRaLU{bBQKW~yX?7VpbSM~E=JKB67<e!;_9cO%1eR66VsPJ}FDZs83gl)41CK{yc%tATp(8f4LMM6@fwJ!Dm5Xh6pc zaVAFm0MF?Syu4sEdv)IZ`@kBO8-x#{ zdCeWB3_!^7>cSa>D+K75^Ap>^VW`uW{6pa22+qzR^6cV*VlFWmXKk@bD-HX{W4#0#cBZbIYg@ zinJ_5U3|m937gg;=cS#s;lZuiMn;nt|ILmb`seDy%F=-S^4uh~EQ{g=_|Hx6)X9@uar>ZuF_O9d^3#3v@UWhEe5@E} z|0F3gJbWpSHd`h~araKUpkV|81YzkFD$$O+HC7h^llx5Ghghlou(?wjK2Z$@+Yw&u&dzUg9OHJ$3m-8&+`H|U3gx}(QODtn)L^|rwtA{|HRcl9~L@Z_yo zUR#$he|_J(^$W}D`}ZvtvXkIBsbD`&+2i{5UTBP07K9`# z(WQdmD6%^tPcdBJi0bQf0pTe0X*vx43a#!vdbGYo7yQ}Zd*bE8B|_)Vu4E+h>=`>? z+1LD0iN-FrVAD1>BmBLwvc{e}w?)J^)|JIrsR0w>82!qFI8BKZ^v|d_C&3vbc%@V( zEwhb($#X6G^y!n()#LBnePe0zQ!nU6M0hnIsmsH9l#XO74o7XNgZFll%NvxUEYSrk zDk#)``N4*9a~_wHs7qJ|ba-NFYe>kpfbW(-B|)a-`s+#3^-CjT6`Y#z$1-X($>D~I z$(1`cvD_uT#8ac>xx!TcdmH+28gp=Az8a*)8p*IX%wyq!_jN6)?APHX>z#jscERVk zdA?duY3+8XNUV|LF8i?zVRw<*6O+W)R}gT%TMP+Z zFeX)6EQglJLBCW%=^6JBl>2?4!PZ88vI=$SpzV!C2lnJueM*RqeoG~DtkiNM+%o zSVx>>S%LeO88BeTUH=C_Xd-2QR#w)TGgn`K2BYrE!3OiM3$9KjDM$%UDnQWwva*bN zVWTHG?KOB_s%ey%TY5`lhsDB$q( zAqO?ds8O@xN_zC{Iir)VA7)BUE;?Bdx*@@Zrr@1?md*}dZ#N*bcv!cIY<(TAQDOiM zJp_c?ku{b6%eCb>jUlgLBdT=6V`EKZb|2g~gMvHf)PltWj^s58bLO_^EvKi1GfEeI zCu{;;gVMgH6F8=8Ja}x>gGXugNu*7iplVeuoNN+cMe&^0iGqXcP zy$`aM*-a%v(5^Vt?(IElTJ_HGTRP=%QjV0lLDSJw=h@moR^d|lrZ1b!Vu~v`1R9!u zk5I+=gc{Xz^!%w)cNMNeVhs^@r2q8MYHFxY+2iS9WV_^rA3voBt{}$7Amq(%^fuj3 z3g1drq$_kHgDO60v^8jmE+&oE0xE!B>&@G@U3~uL2o4FMXK02X{|U{@Zpsf*=_2lW zC0rErE;n!ArcwOS_B*KMKo8MdQG(EVvZjzRqtD<0ywCQiQ{gWOzY1RZe!4A*=ZKY1 z+OzLc*vz&wm>XGiDD+)maBwP0;*quXyG3tLU&mW;d3Gt5qmA3dnZzRrw!!Q7R~yu) zTc>r*O!vO=Og zV_D+0Yeq&4$!fVj;p*!1pw#1K4t%gs3u3qe=hVB<{q2G5G}c^I4z&q-dYnDlLgn7^ zUT%8+%SnXb(?RfASyuGP_1{R5{V_HDXmQmsAs#R@E61useSm%kd4;kx6;#ufi2?rR zvbP}r*=_;TXoX?Eq&qMls7irGzXlgP`?X2CmqXBUpD=^ZEO`KMYUUbgsi8Y_M zhQdD(Cw;+_ zDA_jo0Rsa1s}9VV7g(;rrWH*lD97Wm1sUh4)JOJEuZa3OTVZd9nMj@n=@D!0~JZQjx6O+7KX}oxVuU_3b z@ytk`yspZJ3kGWR&Rg8vd*ACjMrWVb50aPn{eTn_fla?@V-d4KunBz`#|)xKj1)m3 zAuU&b?}<^z;j?0|`Rvm?$9Z`H{hN<|sZL53gUFfxi0Yv#2jk*O_=UeFPmZ<9uN{16 zOjRc$24>El9h5rE1a&u?+TSw{Q5(&EK2O5PRP~y+UTq3A(La#2hYGjSn}}FmPFD!{ z6?XfEl(D`+1Y|nA)M$OB5q;)JAKD7qM{b(6wKf0CgHXCG>CW?C_h!XfVS~70moRH! z9~Eg(HSe0JU^adbmNq8Y-ia$E54&bc9fw4bqnWvR+il(KvFZi~RxId`nr52CrLsEr z_<=?;P?TkLj1^hl^^HH1WN3cy|LmxK5YB^7<}i4F+dd) zQjqX)JbbyBkT_SBqG(ccbtIRNtp%^QFikgv-`UQ|NV+K7xiBL!FlM|CQ{4ha*g9*U zB!oGALF@d9)x%6oO_?4bws^bB9Utc185DPG8Aj!FJM+3?^A24%pZJd!V3MeVp3^*W z1T1=V>B==bmHK1e18%p*1th=z(mEmV#`&aeR%(4y#IQ&yyUoM1qJuIkPm!s@ycjf7 za}0j_Ycd~^sEyU#()0auUkB9o;XD-8?2(EaU3*RYjHCj&4weT5e|9LLZrzZV1ja); zy0iH2TiRTJu)s_L%>^tgL++^q1q11E0L?489#Vm6_u=xF;DSDvEO0_YR#fhrWI&nUU? z(N=1?bG9%q0oZ+S-?>OOxjRwIWc(N&6&&WDp;iIe#-0!))0R`PBRY z;m$1uI$Dfkd~rrlH!vwKS$K(t_b)TlICv}Y`0*!*9oSzR8$I^jNZMFyJwE5izBPOn zN><)8%>uw4Het*>jH=9{(Gq51yPfF~5!Y)x43S^<32rMO6(QD&Im5akg`87JMhvik z!H+Wwo89Mm{UmCj!c_oc71Q#qkO=he-=BG6HVL1-Mt+gWrO&4S0CuM7XFuc5?1el) z0Fu;207_06gMrdj77|8XQriS8KtV&q1BzRLKOO7cE%uosE+Rh_NFK;*so- z@rIp1WQ(_S^=0$fvxQnns3&Rh#v2;iM_E#k%szimmYfBm=wnHg7A__>^a$1bxIR6) zclVrln#t!d?r`r&cu8THgt7`ClJOzOE{N>7x;z7mjqD>!YiWkIgMpRW zZ(}sUO7n@Ija^(m;y4zG$JkseS7_b3d2_|8RryT?Zw#5X97=#pE14>2%<#wwfQYs| zcU5x7UDsxv=s18W)-(jKzx*IId)()L(Xdm4&^uCrZTj@W$Rv=z4+;*p0VOOdB69)V z0IT8}POMB!k6aFHP|K1;s7k|0m(ss&b~cY+w)tMG%v)(E(1PABs>tylADXGxq^( zCv;B#k$=>aE!5mh@j3_>ZR1~0U(RX-=h=5-EDE5bM`bbJe z;OAyWGCy~;y?AkEdL)A~BUJa^c*)E{>P}erqeoRrM+5%w*af&-@Vl9eP`7E&jxKSoTJ~NiWVU_t7_9_14-Zr##|u}bC#jfF z7Q3X99e@rOoXGRc!Z*PWT4z|tl`Ko5Yp}Q8q5G6Y620nL2hVkdYj!+kR$y!4+Oupmw=FS|?Ma}at6Vk`jmrNWk zR+iZbp)hryI?6}}c$FnT+&hN-5;l*jh%}DPTwVJ_hpesCyPV`#nd%-oT%vR98a1`T~T#XQBTt(LJ!wZ^}Z2yoT9sDX~tylFA8KTVF6?hoi zaX@IhIKVAPWb2qX7-)_gCxWKfe0h9M*I#CYOlsK}-(~LZ7vT+U<9(Fx3=Lt%317oA zW*v!P4?!S)OvQ;o8=SMO9uU4Us;VMt#A9>!tGCXJsvtioxN=9AV?piX1BGN>;|Bvo4a#hK2c=lTVC~cEe!^75wR%Hz-W(@6cw3Nx)5J5 zwf)FVlATz}ZQ*c(OT9Y`Mt2X8%<7Wsra$m7gQk!(@=*_-U$8mUQdO2~&uUo|>UyNZp4j@tPhZ;9#%5@lty`uOI% z!`@5&OG@g*vMc|NU6ndC!%-r~mD+i3!;OB&esg<`fEJOlR>EkLhbvRCA_Z@&N=p4Z zHMcJPd%^l(-wy*)2fG$j9JwviVKDCR$A3x+UTtV}?ni;MYJw++|K8$~_D~#4G@YdKb3*@4uA_7)T(+c3q0C zMpTxH$?{aTGN_;=D6lnhwm*LDAgLz#DF!8q&^a*_Bd{jacacSQaT(%QPS%#l4XAy5 zl$l2Q3pgzM=DZ=xFX9OoNi*X^=Y-wgJX`;D(^hs!0G4eZ%f}~{j;cQ@f%AQo-_Ulx zd$%awI?LXBwu^*Hv^dBGir7bQa`eb-$`;s_f7=cXL{dy!ZYX-U{4CtI=S zifORx)@5yR8iTeF=>ZrJbw9Pk$XZFFyxG`W9-FDkK?N%-D<3{|A9?Eq3>u|1jg)Mo z&%r-UcQu~ManY6*Y;j7w5RAMtSCNx+^S@OY&ZklDN z{dG0W&XWhuJc?jyJV^N&-Ne;Ba`>dRlx88XkBt0Hl&Y_7{MXDesLh@PhtGASs&Stn zxnjk;xnj`^w{2?)3RN}XVY_lUR9vE+szy8+rEv_}PA!Yckk|n1 zf{_OHc&57eq*oiUuI{XfDRShf<~FG(O`ngU1U2A@3>(P{_=y*=Fp9TkSeUuYZmd;} zLkiz)SwowoA9~nXBJgq%LNAw1yg#rbZXMHnuLdn1JaS|nbmWnH!e$h2T~_xMw*!%e zUw(8jnyV_yh$fujP0bfCYOU%hxz7W{_==~La|s{^>FMJwJO6rkAi^=lmMKS+Chqag z@Jp}0wQO*A$F@O%p<=@3o>+>vVL`mLgt|((BDq<{?N2m)EeS(cA|r3qHrc-#x%^!f z!e1V@#n#=M?w_@e!!^fOm4e?di636~5B32WOlazB&dDy9n0uF8Oxmg}u7i|R#J!XO zPVln$`L;}CBHt2CRoN9w+y5=iy;S)US!)1QTmkB5vT|_kk#KY20+^fKum5bueXewx z8lmG_`I8Y$O5@#)VHRY4!S^zMHWVN_2I7?I;GdgJRqMbocpLAb?XO?1lCbG)2^&`> zaQ-4_nxtGMbTB|tJ$vR*Q3}fzp}Elq z2((Z1mrP&WZ*?8<5OxxF{ViMQRtrNX*u+{jdu3DcQVO}J!)yYc*e3>S3=_}=We#RZ zNfl)hS@fiCQ}WY+3wKx6HX&zuvQ~$th=C_MAtN*QIKJ4Fq3c>%j*$ce6A3bh=^fW` zlDDs$9M@f#sTh7EQLfONi&T%F(BdZwg(hEgB)DN^q(xREgGqh@_D7aDBOxU#uPv2H zIDsmOrHr};e!GBf{YMwFClXN-0IOgXBAS#C$+meApP(M7KN^ls1xAFn#kw(bPChIV zKsv>cfsst(R;F0{lsJpdQiK?mxn#^17fSdoI85N5$O$d3$^NuQWAviSX0cZGRj+@y zSLCXRe*p9?u8=tSq@tgf?6ro!Eyuyd#*ga5r(;*%-F&tAC}J_Z zDHBLM7Fdeb2YJnNjEiVcY~zvgcsG^U;J*=ZS9H%bcPx8ewoXs=lyMDAhnk?Q(W-rU_HsC8v3tHhopxTplhFTL=;#Ol%f; zI`r<|jYTtrHG~|FO#6DmS}V7_?jBcSYh~pTPFLfgpdQI@8qh%BOz?a{59aqkF_6rM2y zI2U@>W4Zb(SFiT2-}amz!?SwW-@ktchPb|jC0&T1+A_A=o+<3e2OWOC8k1T+WKczqEre`rXW~Iu^l5lMN{r4VKQ5* zA24djESR(!Q3BYh8_CFa0UaDgq7E@k)Y+S8)|edlBY1#rPilbFGU#CziFI7)$B*JO zUv12I_P|s|42=|lpg@5?5cgrp+)DilY5DVM*sp~ouiSh;F;ImDnHZXK`s;wFp6$6K z*nyFFxsOIk$(H!0lWCLytmwM`ntTSQrknVhD;D8MX`iw>mP8mR5l$+vaVfTnU>xZi z7cB}YsImokyYcexZy=y+Ql8P7kit>qn9!;yr9KO-4tq(}Isj{WRyZYa7v9OUwK|^n zUQocE7_sFD7DuigFfW6aDtRL$)nOOV%xYI*)@}=}NjB6O#iKF2eo`^vaD% z_x_pdZDwLJp!8vtd-t9_L(GPIKA5-r>Dq+&c-Wg-8*(3ZED2uluX8x#Fg!dD_h< zYWpFsn5M1U@pMmAQ$x=!TB6#foVH)}{@gjHE&Anzq3{viFePE`QcA>l>ur@LW0B?^ zF*ig9B`zNm)6v^QQ^zG{s||wq<*bnFW`FJoKPr1m^!w%C+JLuaR&805@3DDvLB~=c z2p2vpu-Bb%&NKuhAIeg(gp04(*pXw$2CP(fMFkok6*Uav7Hay?3akY+VEq;dLQ0(Z zErBY2ExSfaUr#e=i0H>mz4ErsLKeIG#!Ec=dFGOyr1W9Q;Y-cfdI5BWfNrJ=1WncZ zng*5Z^_OhppRPsM1W*pg(QwaA4-FIng35`V^bR9{i;ZQo<&s6&sTI_J!z2Eq;5e=K z%$bgB)}X2lka1qSH%P*JSgDnOmJFFNV#n_tqS5(hs``S^0r!)-ri;7Cu%b zyPNs--)yX{iBG)K3qUowI%k19GcA9$ma{e}a>`0(C zgAK-cZhDc=<0&+A!YZHn=hPDcG-FdmiSQv%9X&?}JsW6_-(C z%$T5D`*FIuWFE~JGA8|sgG4iROw`Z#nj?@vFVTdpp**IVvr70(KmkH9F=mfOiPhjD zBe7B#r^hKeZA2I3cjyqM zm3V5odjXHzbFVLjFc7zMb5k!}GJd6LB*QCN7Vx7huUb$ks}P? z#G#5MwBcL{M-7nCn|JRD))`x5;`^Yks%o1l7R%1sJq=)x*(bkDv%mub$seea2afQo z6(oS#EiWkkjSlRol0Ty@Ydx;edz~wJ&$>c7$q;v7P}&G2O&VS=yKwkt6Le0MZOu3pD7k?mP3k1@+5@&+77c9UWTwFuCH}O?G}fk&a0Vc&_-d z+}HGx#EOY?bXw7}B;!DdP-ee7Z>@Am)8u&T42(OdCNNSg(R5_0UC?1dFq#2y-;G57Va{p)vKD#vX^Wr*&(jWMW|r4~!_Hc+9#q;OfK_a7&d;zr}XLLo7+rTz-S#o7&vbPGZC`y?Ph_?=*za}wtQwDk+ z7+X~7Z(@D=kDsOfK8FN3um-A%uU|L$O+Rwt#DLg+5@@D0am@hR*}>2$07C%P_OjFV zOnLdO1=VCWM!!#?-6xC2n2*Iil$mMQx(%`jx|k#eh$$#iRHShw{RgQ7n@LH83XIz0e?rkiuli z66M~ukne7K6I~wbh=<5HM`1Ad7oBA=YKWsp%{I@FdZBe^-fuk>_1_2?S)4@s0V1#| zgZKcm(;GsM9eZ3+v7LGctQL5LKfHXo^x&-_;*o*vp_sCWwL(>gGlZV6{#(I}m2k{} zc7(G~1Ux(4NBm;-=iPe)C^V^3DN;WeY<&NS{e#X<7!hik#2$r=# zWDZS_ytYJcQFg^z9iq2+WY@2mCqBsJqku>|ouZjN>D=^Wn01wfO~d6W;Rw-(9~pUZ z1lY&P0}GhaH&Qy>ADyEcMwytsYVGM2N7 zcd6=ITp{VE)7F~x?66Z|9@+#PgeA`7B-1%|*7oB6%a3!|QQ$VtCP0+WHEblXQ>y$t3lYx0GnefH5cU_mbD}m-d^XQfv3Z;-okVFtAeiU?vw?@ zWUqXTu7rg!~Km`|KQnTzQ4Y0{R|G@Nn@u8YmHVVPIeihj>4~^6DJNd-W z{lCtraVD;k&tcXn%l>}jWXUIyGFJRg$khLxKa~g+O&#xMOz!Iwm4=jECJ(0_-to5e zfQ?-0h>K%f5B>5Cj}&U|s3FJ$h%M4mk_1pim%MnP%;3-31G%rUJex)}JhT5gWbh+T z1&6II2TY1RT{jm_jY}%KI9b{4z;*!*?R4tjtO-Er!F) zq{DZ2J(Q(ie}Daq+#dh<4m_#%-jt$n`iAlT?bnK~eD41O*!q74&(0V(E)`4}wh4`B z`pvKKA(t*$#T3tWPHexhGNI9`m-Iaw-Vg<;UMsur2GWqR&NwVjRh!g|31xjucCYsx`r5OE<+>bO69v84*3 zg}f)@!6A7tOV78ojIjLmeA_sj5o{e$O~0^>q|>E+l&&lcX-EuJotvU-i-R_r1L z&X{mCK06209hd-&{-9ue+3SP>`)Gqiyy((A3j%FNQWgT!OzA`ebdIz zepp>_P8k1WTeZWkG(J!Z%oLCOe%CG&nMPLH)h{vpEk-KA9eh>f8fA!BC(tx^2j zb#|6Vmg4oR)*<-!6w+aA9JV(hM%zWK2(7gQf)H5X&176|KU3x!exC9sF?e|yY$3VY zma(_9N3`Ez-Q?IF@iZ`q=ElhvdiviB(Iv`?uMNK2Huj=uqCoLP)dG0nN1X)PhaP+0 zJknAzx~Qk35+~%AakE`J?do?hiXQ#;p#gm~s1a#{i3C8&j80h97u^}hXGTU@{kL4^ zuY-=G$f^6%exHpac@J<S1bottCG|%+28BeT*2&_UpN~_s!Mh5o zo~5l{j#i)k=Gv0#_wVsB_#>&IK8}7jzWr@Q+&;S3n(j|x>dM#Hw|6hWY})$zq#7VZ zc@rAoR~>GphB|JX{A`v0g5g|61G2iIb7^q5V`V=Wv0&lCA4G-Z{CFEe_5CeZr?$F! z%a=l3Tr{WK#W)tSKOGdxCeqS(eOq1fNzxMIc4+NOJqPEc2$!wL!xMsXZ=UpJx9)w`vP zLPY$Z`!X&80p%^f=oLACK`E&Jc%FD(-+gJ1XT+>oesu3p3_7clal3n_Z#{U+9lt&N z1GQSxy6B6uf^n#`(!U&grN)@_d-n+74JJebL;o7G3liw=AVJtj48s!nGxM$?^8%JS{&z_#75%BHxde_-Kup{MzWYNF%1gp^&b1 ziV4i{f2c+%1~-?lpFw`;BLfi$#WQGi^8JH5yh;zCPISNLihW zt833ryC_>vY{7%U;u1MWH2C~PJ8CyjUcMQzS1+eeOf2#QQiO|SA_~A(aOabM`23CD ziL^UC@wEl15%>h+0WZ=pR{k5$PRko?N|=A*0*8fX8*%sA-lPTP)lce zOVUgz0dJ0-s~#-UYMR5^ehk_(1=3%HZ1CW!;Os!OfELK5kAgymK+M*Z(vG{9r&%bO zf-{h;^h3n^1*)Kcx_8>3g}&)DlRaprz?rNjskC2?qbK(EG60OjszNx(BBnag&VaPH z{NlO~kdr&!eD4%$1tB)oB@7{jfChjm4Q4s}avb3Ns@p=-D%2 z;Yqb1dt%qMPl&_UFMCZdOin~g3I8aDfdD>hEWCI&Ir$3^5jFyn0U7IF;;7B7M7wmP zZNKDgNLj3tpAttuM1+I6-2n|5X&w?Y*V&_ZU7qUvG9q_liosIbwNW!e=ojcjr!}Oo zLvcd9Vhtf;AT|8u^*x9x_`>MyafL2@IXZTik-o3;gGTl3GL4F}(=I?$dYUj=W<(6d zHMnjI^73dSPd*zVA)uu8PcN3S^*el+s>&qrb{&jJvEf5gMAPv9{avYYQG;u3@-J~F zDKLHDG@j`@qo?>yPX3 zys~=@+RGL%LsXI26f@cxlNYzg@?x5D15;;+{&+B`=29b>H(|X{g9M(9t*O)luX1y{ z0QnPd{qEg6g>IerF5g#|Jl;CvmFCd8qcF-hf;7q+4^Y@0UlQ%fW)g5eF8(&MU-Wknsem_~c7=l}RJkh}Y<=|lv)L@~+s^COk?y~oI6qSD;z|s| zXo25#b@<-)ETg{q+`S2HufNSH(e~Tgf9PDJ-pam?5&I91U6OnAfd7HIo8uEgo`h_f zcQ0>vh~5F+h4uOwPgj@+_F0{Ed-!CxtGAb2-C1)>dbe4?Zuvl|Cas4vBWu33wx(r# z37HuouYG+-PV0-i6Tg&i%h`O^yYZ!G_P3Y#8sRTYENMO@I;%+Ubr#swz@9--sE4#( zHEm@c6lPK3vFX!igU0anl27{Li*mQ{`E095R;RYir=0rXWeEDjghJ$=4Z}Jjrw|nf zc_jGn^Y7k;pavU;LGq^j5V_+Z1;j6+`vFP&g|Gt6Vb|rcx~^b}Jq;TFR+c8YJ@g)u z96($&3eW?O9kusn=(k_!rH4EmS8B?UK6sjg%+W5ctdYeh2s5Y9HYE zLPva}6kFJmgrKc3o*{+S78YZ#&+iDH3eBw(vg?|kX_WAEk?cQvcrj1LbBc4wnZ(i- zFw(2bl76Bfo$DS3&8axv{cKVaY8JM4rXA`q&T;$=2IB!orJ!7bQf3rP>(6(S%o#Ar zUr^Sn+h?A;IT;y zo5(wG>HeUqN50AR?TdbX8yGd}Qf@N-y|OYqdO~J{eJS~ACLTBBr+Y77MMJ~%ZkA-FTmskJ;=h13e+MQ8weCA@?D+9NAmLt&{JVYOSy!BJDc3cTy#?(HNLy8h z!3^XU+#Ia2Emb~95IQQe_lJx%TRqi80Uton7UM1QH;N;Pg2rlp(q-Ink z&|H(1g?DaJ>t}x!t&p6K9(_JO{udbri*0OTP63plVWDzKf(Tlbj3J2j2=+9d6NA13 znh$i3BsE0E;53q`3NgY0)G1Gblo_G+BzFy4CaXZK?$XDPMMN#P7dDsJ;!O-`V7g~0 z31A_aGZr2*Q1S=0oW3fxjEPfev>YhE80M|^e#l?SeMN^jD@})vW$j1hU7=X|%^OM?6d4a=_nhv;Z zFpnijUZigJXTKMV8Mu{l=FFM!@R!tJrkl_V@E#TKzP{~0D2$nPIw7HRRueW-TyyB) zLFVmQ%dY}#K(K5UX1MsNu~X6lhT5P$3m;g- zS#ds~FulvsQn)g7gNcdtT=ivi2Ffy|$6d!HE0AI9IE3LLh3t*^euU?q7eVyJ7vQ&B zM?3l$HaGXu;G{wx^6}#wLbioa5vR&+CSqKk z*|U2Ac@YqL6wyo_?7PzV>Bo1Ux4=v7_1i~SE1}HF(tx_SK#)Fgz9UK~c2419ankTN zsN(e)trES~mA8twhK4JAgw^S<=zt9yZBbzozx|XbXMzWaXd{h~2gxJ;o`m>Ek*)R6 zX$n_n+qR6}pMjT2OKo;y8K97m4%TMaEa4UCyoU{A2GB^v{R@Z$B@N^O(`!+zKumN_ zIflaK;^E4NX66E^b=>yJ%Ad$I?p&)>c+X?M0pp zs9C7C(?g{2GJX4wTf44(@!ycI5da*hJ4m(q(BN}fD{5Q5d)oArPaAR;m{Y)7+GRPj6< zZ+R;$r_fFH<=Cmp?;BNjWyqzfbns_+KP2i7$Z2O+2CMp(+SUR%@e1r|6q^@+&X(CN z2@{E&I7m}V>)~1*w8$c0&- z-mTMbe^J!DxHo}H6;xXSVV)dLI5C%SeoYDB8ORy1o1_0%H}uTw*-tQPm6n&E3J756C72v% z{lzpBnGWz>$B#2wlzR%WhE9=#A0zUX@1K?~JjU}~@18EkSrHz@iER`zXcv*uBKTjZ zv1qdi1A`7s^ZLQuZR)q0a~sK+KpKdTm5t91=9nVv+56-aK(^2D~TS z4oYYrLJBv^*B}1~>}TC|;9^;LJhOOn)bWKM+#V)gXqRi#h*Lp zNYP`u*()j&ZdW*r=>lxybpT5+R5GJ%spA^YJ<}t__~WB{mcS?ns!TIBo`^*d`&dTS zGszOV`tF!;JoiewOL2*{F(wn-8o^=BnaWCZ`4>g5Evki4t}KciXyON*yB^krRP)g1 zg71y~lgN#%EU2$h@rkm5Dm*Cfo0=cIwJsg2sYzG?i%id>Z{X;H2-RnGR7+rtV!AZd z6LrXXK(p`-Z14CVZR0neU3|2(mT{`sh6jFKE4q(4CQP9#=&musB=F)e&EdoMT3d&8 z#`Ohc`BDtAFHt;Sy7lh$2gw~BdMXao*7{Hu-nQ3%aqFaAoq|dmp6BP~Eu^>}Q!!k8 z(%l7QB;2_0@E=qtPHu*d<#n47EgbSPDD91c|D?$tQop8#sepDI&9>1+8rws`l_rYM zUbgl!g_hs+Vx*H67U3n%j0T??1Lq&#;Nm|gCd0aqdX+m+`EYa8MV-bylZcn4(uTc! zshA#mabQn_j~&#eM0+m zmB}(49@n?O?2`Wygjdf^{f{Z(f5`kJdq@~ZzIXb6iMT`h?H&-q76or9`O^GveBuB9 zf|3ie(_+S-%WME7);T$)=jA~Q14Se}{x<|;v1a#_I5t7SiKTkXB!&mZUe$MV?AdJr zqYJzSOJ3r|mnFL|cLrZAe?J)gF4J;cWH(G@97k3qPSAZf*!34U%i;H}A$=RN{-t{R zUiczGA-(P|A7wNv64#SQ=y-k{CQu9Hd;}LEmJ3XcC6nJfk^z!*W z$|4OJ9GGMVD;|g~quMLUw6f{imw1T46Nr{@ce`2RfqW$1T45mK<*u%dDf|Ojcb$Lx z3&Z2CYi=A2%(f7cX_{V#CmZ6-5o3zMzdLf;Vq;Es**vCw+30FUgwI#?&;5d74imlF z?30jHly0&K%X}CsPuCC)!!J^ikr~QPLsrXvhGL5`dWn3*I%a$>XOzFI_xE-dv z>bGZv5QGb1_7ouO@wnO;`*y?9d@d@L(FN3&>@oTK2`oekqM?xpfpGsmkw=8t@re<0 z65PPVm*K;Q4<5|P+{v6Q!xwIP1mtl5BLkMm1#Ar8fB1kP3yOg&&%Q+k$xnaiA=IwVS?3xep&OccC?VCX?52z2}JnARy)R>t``9;=op!>WQ zN>j`zWS!#{y;IsJe}xua;m42Pdg#v{enji|>C>$#iQ8wJY)G)rYNiQuqql{0LNkOV z?%kI$c4?p3jY?}ge3TI#)saI*7W~m8`i(UusfhdHtw>lna4I`S+Q9t`RpA?P;74&{a+@M;e^6y{-h9nsT*hUIp zmlpH(#VgP>tSt;qR+i=tv+L`gFe8iLeRVY;UE_a8Vg;76d(@YyGi1_cSCxyTWl zMnG_IOld7>(3$vy9I%+|7>*Og^hIoCWSNMp45nGIBq`Kc{sx$75mPdwrZi!(=+GfV zbk?18U#l!b-MNMNtqi7vOXT9tgP*n4(QU&F@+GJ z&|xVND>P%E7>u!iT4^Cg4roYs>Gj>0N>rtcJRQ`xFX^+){rC)Bf_-)kW!A;f!N-nqdY0ngCoU$fLE?{*QXc(&@Gz=~ zb}ax)jOs#l`+0QtEPqnJbY1Lb%-AO!C0;`v%vIw`n#I^evSY!=;P0S+PUx(&Wm5|r z(6JB>4X}_H?_*T}zVPM zTp8>G?3P4~GyjRo9K%B@Kx1SPn*e(b6HQD4kN%F!h!rG8ix)Z4zdxr3Vw93z=o}-N zD;XJRn=Zi(v48MUq<#J-MuFp*o573e5?B#gM7RFkE1O?fwU%bkN15t#s*_(>Xeb~q z5ohG-R=BtCF`5lr5jd1#mk4*Dc!GqZ7x~{Y|J$@*)jHdsD8C&wj(HVA3rMk=nESjBRPr76o!vr==iMM_Z zeTF`A9vZ0VXy%{-Z3yr+E$zt)tugkQ3w|WX>W8WmV-Z<|7NN+I3Yms1uu+VQa_SPT zVvVwC_;AYpY?N*!@=>Kst=AqgVl@;ZPSHb$4jjdonxAuw)r;94n<4`G!OqctKHPBr}f++dzZb`wmLeCm5A{(QGL zd_MVZH(olrx(2J595{tMns!B~-IA8?y9YarTBUJomcXVt*yV3z!8PiF?}-= zFCOJ-jgr(SGMknzaN_L90K5(lJ$*F$H7Uz)W_7X6+>%5`&r+r72clJ$QPt336POom zCKSIfUy@oukSh(rj`y)K08&U3hnZkrCT{Ck)sml=pk`RBR0aq8Yk;re4jHA;Q((V@ zlRzMFR3BE*{m90Tr;UB#hI){^ijK}E+)8W$+t6abhL=Bo4kG51{PZ-a%}R8Pazt4| zFq!XYID+bjsREqtR|w1%&nH?5zv;`F%yw-Fqlmyle#$xwmVCymiZXj2s1XVa$~ILY zfue&azBYACPMNXAPg2M<0m!+bFale$Dw&Nzauz}&^QUK4Ys=I1lL+kdhZVXj4)%+o z57z2TYPd5L%RNgP=1)L6Z|}kk`=coM;e};v{2MZI>w!#*9zVvBZ-q1lbh^VJ>F*C` zY1CC@q;lJ7I+rDR{U~s-NAJXBz7F^qouD|vTRKqQu-29)Ex^baYsFe6kCd4?>(;En zlm{Dz2EoNq8btIVP(K4&gd1myy?teYM$S(lCiYgG-hllmO$(2YUA`jQS#l9}*2<)= zeX8|a@KgK;uoj$id>)1l+`CH&C#kJ887PfVy64t`l`~1OK~CeStv2SN>w$tQ+myQbL-%AI3t3D# zDTYXr>1MyGkCDFKbbq49x92hG5<#of&WxqTQ%7-(`Z?*|SgW&d{fWNqCHi$oS{3vu zz%Jytg`6T2P|WgjHd$(6Wu;A>RP^bbar;FqTFbYhrWa0{6$#@`GAd+WMthmMbDNZ3 zYWu08`QDMqF==bsUtX$pFIBUC`8cVx)Quhgw)gzyIgk{}YV=|9WTnP0d2~ z#$t*1q)g{(#lrN(TUM7gwyQ-*M z&;Y@NwN1q&|NedRS+n4$r^WyPb5Vf{iNPQQ{DD7%<$%I*@64J7DRynWi5N7G2#r)U zp0Vn|l;|kt?GtqwL^Nj1?6{InJ=3vi!qYwY4R0plD5g<#ghh_plwinNdz7~+UN>#u z&2~bktnq9xF#?Lk}1WsP{&mSBNGM*)Y^&>bHjrfRAZ>f-~=3Uc5JWS zf|^9Z;;XTB6^D`>{* zs;lo56mV=;Qg7BjOLd>Hz^VO^a9a%<5r%D4SjrI4oWfSPUwof+2JI3lv>`<t$v22&SE@*E!UU$pdWPpcA5+3-a5&%gV|MPaCGX zku3Emo{MF6i}Yq9Q2|;Z^3uEN>QDLD{B^QPpqF9T!?k8{C3{=9;Mj$Y#7rD%8A}qY zCQ`YZH;seTevX%yyvM8YmS3MjOk?tU@EK$kFX)3_85ON4if`S%Ed;WZO?)T{paAmM?7I=3smC$F5c!H1 zr|^~`V4OH?Sk=7Briu-vQ{%`#lJ7IVZ1+71`Bst0 zD=s`Nzwmhp70IN!+S)!H$hd-9S2$)vmB6qhuaKZ1p;)-T(UG`UlmtN)-jHJ85(RliFwX9kVdYyUmh>+EUs$&g?>Te`PzS~J0ha14^feIGJ zKW5-0G*|M|>1$uY6WeH6!1rD0r9Jal6>31(BCQiH8r}3gsvmIm8KBUM$0)Wo3(MFB zqBtlK*b!MH8C%;yDrM>Gv=+t|RJ(Y-!(TI2gvwgUKZc-MM{O<gXu{psXs2P&agRT_w5|1 z@T8~Fjq`i#Dit@8szTMC6C<) z|44!5raJ^G*pt*r%J-9@>?Z9=fx=k%eIi>MWU->09L2i`UEqu&?h_uEBdKmv3+E$S z9t(^o_rs=)oThl!pxt6{3nGk*(rc=JlSN_5%@NsX)$(W9ki%XL@)QZR4c5!`7w zxG9ur(^#P|+j_?Vcnt8n%@1;Q_yhY?h?B~u)%c4tDk>j6I!bi5ekDEZvuFArE*lC8 zm{=_(tN4algkv@Qyty>BrsKQahrbKCUSGSMn7f~Ucwrb1R9Sc|=*PAGDC1POE-z!T$cO46WKJu6LD2j)8h*zc2gKI{$5G; zpHR3m#-N^w67)K|I&YnLnp&rg-BS!+MX<2<^!Tx3C9&b`j%Q{&iOru{1FyTd!iA>| z7A2O30YlKGU|JW~bEAG}S)KJNCWgKui`c$836@SoZwHm~tf0@GL?2$gTNv62^#7pv zW0lK|{n0%+EYJn=l+rfQf3cIH|E6fTIC{jeVY88JEZnvc5iZMxx`)LTFxouelVIZp}K}f{|-wzj>1RP@*wl;Xr-adO{0AM{TW<5SHBQ65^SOaj3& zpm?W-1wVcEjL3&5*<-9Fk5sgajvW&>qCG9G2JkU2 z#-mW1y!yO_wED-i9qJW6hL$4Zfisxj1B~s zaq1b3&&SddC}bWP-b9TLOR4$zkuaKM+jt6)8=I!)lzHe^fIWx;r>sn2(4fR_n_%KE zT>^{(U+*M(7cuh%+?rg2kiy3xR#c;*L!Gy-J=1kD1R|4_ zF{7uYEm)XAmSz$$&a`Q`M>)zpEJctS)sVa&SvP@a_wScpIjw}Pj-H>{!g(tIXfK!5 z;fqC_<$!evcHe_MCkI>cg_DG*na&1~g*Qx&PUQJ>&J5CE66ELt`~(D{v1x zbmrD)V=yNJZ&fnKOX&Cw(J3iS+@Ed*Jtd&y&Z5y1CbWob0NTr9mYrkDGdKnyD{cAt zYC*Q~tskj7{9L_;4jHoS!G?Fn7IrDlA~;);QwHS$bcEeP+x&rZ%JY`WZ+KJ|ioId) z5Z`4G6*<2Mgs+}43F`q~cCajlJFL(e>D>TN&92|QWs4EVL(&37VxU_WzV$vB#&^GE~`aD=U@A z6ga;OnSncP1uQ9BrW?KkQWZ?h%^j#iZ!KfbVv}U9A=6z?=B?;*T7*fUW>N4sz~KN+ zf+PYbgmHQ?9Tc`@+|H92eVyMshS3}p?95F8Am)W#E3Tvv*g#Z?Tof%XiR_L=Lnyeu z{Gfayj~TcPYprH=G{#Ln0(neA9bh2~Yz}jdN)rtTrvjO@h}+bpuj9ucetBd;10V}% zycl&%DZ#c-|14#e>@6`*)pRKU838YeOcG)ft7sl86*hiquyN8%;ZXfznWMq)N3%4kK&IV8eG7fu$_UYAYZ1q#0aGQab|D=M@y}Mv;%51(FBJ55QmvLo~optd=Ln9?wQ;DAjo-sON z)WnH%p+R!SwR$d4$WhHPXt5Gu^=A)(&EnUJClvD1yU}U9y26?5X2XSsfiFz2usmOEN9?%1&`xq+xkv=`9je(4IT<4) zwPjnyY0-HS%-8a|hn}9?UvwM#2$=`s8*1*HrPF-##-eT%HEDL%&)N8dsvb)@Iq-q0 z!1DR=+v>H@uXrY~V6aLuiLM_Xu?50^^_M$b<;;^YE6lgJiAHn%Zjo!Mi_FC`h9q&hbnD^EqM;=c^=*WMF4=u(B8jy4*VR`S$h8-C~9ySr(*WvhVyK z0bq^W65|hi^^}7fGR_|xhp`-}SohcAsEuqY{6HuNSKrqGm_^Mx%+wCvG0k`QD-p`h zxKfq(-|t8Yo8TmMckc=jR?p7?7WcC)u&N|z+jtGdcd`&@qsRBV9O}$l}P5c}S9CGZ;(*57)tp_>!xP_N52XsZN2GW@Rf($bH zynUg!cTib~&x#e8JD~HyRaH6~9WrMDU7tasopr^dzkDW+RFz$OBZLh9>bDc%hwq_n zjdm|mbY||kzwMNk^noB-;$TJ$gWY{9si_>zsW68%gncHt@QA3e3BVOey(O07K#LaX zT3ST8P$*BlJ%nivP&`Nyx&&aP{gX-CAus|!2oDP*{govhMxeQwY}m3w9i}ojgAu@@ zaba*`N{^z^eP+0Qrlpk?i;jhis)CRS2Nc`Noft=1A!99dP@a}rKys)d>M zN>k;L8;;V}ZrSaIdON)z54P5j5ak4UcUIr5koQ&=q9m0%dWgc<^`p)qPAq?Re5whbm z<3G}`6GnlBW+9-)Y$E>rFA=l8HFprP7>vA7-r*iieF64-MoU9`6m9A|1rZ`%DQg{% z%tW~L?13`x6wzmJ@*Ig(YK+(nE`{$GGOe_WQAi1xY`g$4+v6?lv(RcNm_!}+v{Ec> z_}KV-ClW{4i(`_r6jCyLdIl(X7Mgps|6iAwM!fXtAdGg1o(0tVIfMNe<)KElrcaf+ z(KR=_2&q^S%)3eKH5kIFvOBq?czdj{F}s67fy}Ve?ee>SkKW)kp0NQXO_5)aB1m&| zA_0EmMvX#jnPK>Z!rL+#HI)ESIu6b&Xl!Y0Wb73OQXL848zt{QeY){8Z{;+@t?g|_xRdIm64a>`M|dc z^UQ_7|Nn$uUh#RDuWmZ~hGPc;;gct=8Dy3AxJw=55NZ$kFnaFHgn>RTY;LShTC;=CG^5uE`b|!rw*z*>!O0KOX2(Id#Kw51vyl zHZ|H9a?!7_dcK!ZX*_7C?q4DaE-0>8b6eo!i8p1p_Z758VjC})QaMm+zyyeC?Rp+} zUe#INP^$Mcpfd47cqBj^gf&cL(uxnDlpjBP>^=yp>y16>cYduj!RC45#GB9m;J=C? z;Sco+!^L2$YE#pHRj>W%PY9g+{9#G*s*i@(pp+5tj_!>{qON(|zVv0cjm)u0EFFmn zZf0&a!Ii4T!}5;3zSi{jnbd6&LbQ@0m!_L#Mnu0biV;tVak#Ae({`ZJMK~~~+&Z1m z`)IA1@ohb7y;6xHZ1XP@43Dp}l1PSR_;|Pm49)P)7dLOs65gPHd*nDeByEQP#4Ec1 z#4o+8$_oj9__ZQws`yg24KTWVd&LBAWYq~ylLRrFm@pQoHRVP@ZHBr(NUO)DUryuc zW|4C~k&C<3yDaI97F!Uj?O8wc+?|3GdQ`tRk{@EH)r9$+m@l4ZXDS>xE}TuD-PRE4Q!OrSogy0K z{3J8w%JBIeI(O+QakX{NV?gfgk9_S%4fd5k5|=KWPvYf!wMSNs*yvU7@LR|OYjOFd zrM?Yn^1V0PMdA(0rKje{IR=?*HkLd+Vm z6qwSQ@h6Fr5>r*XsJmX}U7MwTTcw@$es`KRwl8p-cvQ@;N*1+yH-4WdktXNZWU0d* z4u+4bvax&7{f)`4T%w;%QKG35(}1NqW1NIL$-%a$qZ=kovn literal 0 HcmV?d00001 diff --git a/graphcast/docs/provision_tpu.png b/graphcast/docs/provision_tpu.png new file mode 100644 index 0000000000000000000000000000000000000000..cc825c4098eabc62e2c541b7045f33aab3be6a78 GIT binary patch literal 40113 zcmb@ucRZDG{5O0VA(TUs$d+ufBkN@EkiGZbqwKOm6ha6|L{>Igm6e^nC1tNNv!BoT z-Ov5{J@>Qj-~GIv^9Rv!opY}1`~7_0>vKhEsL5XyE_1NI%3#GjMG;(lOVnbqd;Zy&{!xrMgH8Rutw`HomAk!KEljkjnUOVpD!DU!}+ ztYQ?+yMI!gI2aD>UMPv?{AU%`ABG_Gh@y>|bcW_*E%_?ENp}XoA1Q z?G)&zw!GCS^xe?1aMb+uXIx>ZelsVAu!ne-L=3I`Po{LQUc08Nqm#hgSJtf7@nqwhuaJOCAZK7GsAL@F>hHB5wB$pJ3H1EUNN6_y=o`@EQ!F`M@t#| zn~knNnhp;SO-xKOG_=A{*daoeG4uoo0WmQ#0Ra*+K0a>Uo%}#s+wF&gK-%x$zeDj! zCMPChsRY?YMbivTO-+NBzh}3$$^-|WM_suU5gB;}_x^(iRd7_yk6~eml)nDNK#qL> z>pMd=u0Q_%>hSmX$3aww3JMCE&(Hkt-D`Pu+fwcsrMa!`fm9z2L#wjydV;Hi!{Cb7 zrAwFW?d=g16;)U9Yd-5P&wlZx#l_^56ybn>e;WM{utU_|uMOsQN74mRQc?N+{2KoB z>C@BW101Qqz|*&~F%Aw6$lPR=qlAP6UGPa+ef{z4JDvdn0ZK|rh}p09*n1Jg_X|fh zH#ZY+smDEg7Qky8hnnZu7m_?b;YWyRM7`!)ql{8{nokDp1j;%6R|a#-EIZmeIx08~ zb#xY2SE+8?=qY!Slfy%t#y>u`o`{Hu`1b8v&5xjT>xrGEew@z}6}I}xf#;_OSZ^8{ zURzJV>tKJDJU>)-If+`z^e?HaOSPUDDb&UJytg*yT6z^t;J!&|xHeX*uAng9=$8)Pu%IATNMvM3rno=1NrN3*XWM9<=hEux>e24X z6$%QQ!`1bPin_Ww&hEZ+VG<;n+c+&PttgVStEiD}wC_ZMA*uBu!*5s;_3&1l;r0>wVpUSIvRSfj`R6V&%7$|*0i(|znHOdab?`L#6x~BEX*u@lq2{%PGe{YANkImJ8-(F zg`AKO=}AnGaS9jp z@~!6crU$98S0*P7lC2#b9BSpRAYNOujj%P5H#&vAHaqieF@`P)nVEg%PBu0y$ZLtf zG;94T198VuPUXipm6PqN4EA*n8rp7U1#C-sQZuv)k zdo4<4I<39EeSY3zGYUV#eAuDt{rmSajJH!&pd#ev=JKWXOifLVj9f(os8k<3NGdEW zG~&ib8cj@fbaYHiHh=y4wXkrLBBm}-M^|@jWaPnv2Ny0}c=YJe#qh6}-}9oE%f{rO zE`84-mJgBoFH|l{Ljs1(|5}mz|Mb*nRb^G~zuMouV@AH)3BWaFWn-)Bz&StJ>6gp- zOY8dt<%g|LZRo$h9u*hoVrR!o6?=X(*8HKc5VJ|BiSbH82JEaT6*RNg)-zDNH4MB!iIaoe(8fjJ>7 z_)Nc8r+If!*|qt!TATfJcSzl~Ka(q7VgA=ItU1_fw6wHjQNO>%GfkJUMo%^epQrNM zCL|c}FZfnmTzmo1)zvM~{$xMIA|@uL zuHJwA$Oksti#NShb&tlPqm?_J(VFX9@>0UeP3F`u_gWeJ@PQq(jue!XJlx&?Zg2BZ z5_5L@J$eLN1S_Pp)YaM9Sy55(!GjT~TlenWV~bw>(JZMtgh-)UhYPe|)pvDw8`656 zKbDkX95}tJ56dL%!egnqrKO>p%(cc-RiuVuzFTV(;=2>^>NQ3MBM$ayw9s&%N7Xn@ zpzjZ-mHg0Ev^*N}$BMnhVQ`Px_enS$u2V?8!IBaVmiWa^cKuj_M?c}!WI`r+^w|snj+dDgv#MjBFs8;9ZzSp>#;^Rj` zwIXGb<0WQF>}-EdM|b`DBj{<}DLnXics^^R#ae1t8FnAiD-Mpd$+~{$NMg6Oy~R<( zsF&a6G~@`o*zwDVae7YK&Gz{lgq|4D-Kmk(a$M)(AlNsu@WN0 zcsMP*ngzQi>ZA<2B7Et+x72Pvr@B4ewO!0ZX3rn!Z-a%T$lTzrK;BsEzRV3-z)8_Z z=qX}~i!Py;FEXPDdTc887F_ES$K|YGlLn4v{{E8U|KDLnYIW8Wnjrf!2M-Sq8yj51 z=bct*eAY_+nc{V~M8%YsI&@T3W2NZ-O&J($@-v3f-g>W=K_zh#e1Hp`_!t)F>Azpk zVIgODKkaK)P=U3epinzh!;;9x#pT=9|PGBEJ_)qd#-eJwuIGdatk zprFFrv`Wg#73Jkr)YO7hs<$$*(aZP$ekI~TKEn~=73&w9CS&!eIg)z!~or?#}T@bmNk z2tE%?%S++Ze>-3YFR5E>Fp{f$%|-I@bIKGGc(U%AGfu)>u77Ul1I~u`yo>G zv|RTH{$dLg(;>s7qqGtM|B7p74I2IS;J{psa>cWqE4+@DG1mI{`1pvMr>9ooh+NKX zUS1U9DlhK5FW z8&sMtxIp{s6Ul1Vij7$!q2V0!^*PbE!|)H1I0IOcrHQx}*cga0PJZJMqD|(^LI< zdUZ5&0<&Q*n+*k8Jx$l-ku_9uSa-z*1=;qZjEwWSDJJuxO=3h%)1jzUD>%RPG`?AL zg~R`2^K*$Q?t2t{l_6|0N4#@Lw0eF0cb6fLGE>-h91UI5tvNH-z{An9kz04yKWF0x z^j~M#{mh)Dx;9c#zqep0=|YIy%oXM~>iX;EnKU&?^$hRz6&waHQ&gBvVPR3pbs>?i ziOw>bsj5S@>H(|4wOuKsn^z>21idW5pnztN*IVSR^eOD4Qo?2kqa_YNuz^_oXU~O3 zx;N$8Oe5kJa{T(m3%9@XZT@#wp)!lPQ$1HR7oduXj;8hDLZO5ECrRnyA}sc%AMdUJ zTF4z*sg*th^awkY=EjZnjg9p5bhz@c$lcuN70b)Y20A-CM@K0TDk>@fBmghGKO`tDJOzb(YRVKg41i?V(>`8K9kF_0-$UiOV4k z%`GUX$)r?WMTHn@`NfNOd^e4t7zj@a$lk}5UiK1AVD{f#CZ(lyxo-M`{wWZWqr*c| zTz4QxVPRoUpAxvbZh5bbUZRtj?(SZj@%?d=;OoT1qpx8X5fnAGnE6-SoPk=0kq<$q zM{TXG`6CX+I_WQ8?rwMUkWliPP1m|xTU*0&YHDiIEj6JLa-u|Ti2Ff($-;Pn92{I+s;uOIb~mT%v^XjJ{rvcd z@1bxHlwK$>&$nAs#I)H|OxOtw3e1d;i@-I5b=Ey=m$o@tY^b8r16>m?pJ9_)tGE&i zS(Hi*6hW0Ni3~@4SezIv;v_UAu(7=St(;RD#n}y<*fi)cx2|qtW+tOsy2M=`g)5Cx z%GKmf0fKf%OzioMI|^3Qynh#Zr^8eOx@+ipS|B27wb?h-i)r7I@f_1||23OcLkmoPLj zDbXo()p^QXYFw`aP!@{Xix)56XZk7S4D7B?K!?yI&mD27voI)57K>JUZ>gVbt;I)hwcF+1Kv!R6(3yI zLpc7!hi?ABGYP{PnVH|ceak2;?7O=>kX-8F=?PbshlftwuQDY?H-+0+oBaVC&&};4 zI@`_1XMca6MNuxD2~?e!opNCz=9}0U=^z zW7GP`qR`Lt<+-bYLqJ)IjqPnIAwVya?tN>c0P;+OX#xLN1?N8t?f+El`Ol}S)LRw; zU#+V<VT8K9a(|<6B$i{>AkQ8CIPW74PRym za~bk+EF4imIaC?sYP#bPcENi!>QjB@!1f*btTsDV2t5H@%yOK-=TW27Q=@ zt-V-v?W0d30BqjFc?u0}0kwh%L1**czCF<1&ceyb$-_f}XliMNhKAD6(2&qcsGx8a z6POi+1O>T_YQGN*xX(5`T3A>hS+qi-)t@w{&L-vACn+UmW$}<+pCV2ZRG}x9QJUB9ir&uqbM(kE@}af4Mmv z&m=E#yj}s{x~$5-enlrKL6aOUF;+&&Kn)2#+o7YQ%c2YVmks=1+;8_E!0IXkXCPX_ z=H37VDLC~$0wCF(sh6CqKof|!PrhNy+@qAy>&`EKW<9>j z5r*P9$*pRe21aLj($Q1SsXD~jT}{jrhBBRg#Mp_UV}V$Y%LtkF!mx~ zz5iNvN9Xw@m6476m5`TQ}auU%68O7W)$-h;iOvhg8=~&_mw1j2W zQ<9z|DTH49?VqZr&0XU6n=kH}wjQ~~G-o)7qz|d^|9Kg?d}t@{8~(57!bm3h)fVjD zk3l7G3z#y_>xx%9E;RYxiBG;2_cI)2>h)WKO7?N=Q>1q4$D^-x?Q|psea4^ejEtvt z_quqjwwYh1zARicWsHST72jEi-tV=|Nb~L~8WHaLa4FfDnkrUzTd>&Y-&)yGQsjda zBG*2#bgehu$z)q`y7^U$;cH!0c-7T?^n62Gsu~SqeUG+L_z59d)vRJ}K>p>82b-WXC4i)TngltubB5v(YmgbhK0~y{Nr=Qo22^GFK(->`_l=vXf*Tm#<%ZMQ}&id zxArP#JZd+Cy&0vXPH$duq*BctX>a>=Baz>gg;ZK9B!zjr-)7KFHfqp`mk$@~G3{G= z({%V#Dc1r8g^+buDvj~Na1`BX@7*Yhj=?0Rb8nn=MFy1 zHoSSd9=GE{zx%LUsQSi6xMYtPsb*LRV?udaSNf}i_Q@-_uX@%)GJT2(w?t;==2^e> zjqSRTGg}v!1n%=RYK!s{ga_My5OiEAbMY9O@po(#ybr@MFz1lfcm==8Funh?H8kP}OupkLu)iyEh7GPDaG6 zzo303^mlLp#a?s2bya)lSZru!LgqJZ)lEi;%o3{EvBwuQLnwc9h#-r za!FClUre}vQ>;Kb&)zr&gDw7 zxiNWyPSwbY7hb3|OX5Xb?9T`OsQKq{tvZUiCdIk+_+1ZAQ{uqDQTP7+sk8sy;L0AD zLLBjb#>(so|*R zlU&=#qfwto2sAah^+ONj+Ymtu>2+9y- z3w`9QKcD+~UtPZ2N=V0L&`)Re<#6SualMGB$aLf?Lzo7&!`kXj=`}go{fhE{HSR-} z2UkD{SeEv}uKLMZa|5g`=4fgRPuekW>_CmK!#UU-pqlK=k z#X9~e*dW}(MeZV zi9&z-I;s1ZHJVDm?nPFXDTtyNHUzegrKOgh-m8}{f3K}k)6l5i*H%+g1DLF!pa5PG zYxF1gUveHE<$9cPQBebHdrN&8Lk?BohzV20lFMd@`KmFKS=-v;^?WKe_J`(KyS;^X7f)7O@kgq_BDftS60?PP1KtEWdvL$k54;Qb@$O4xU%Vwylh z0act35Q*Sb!JkuA?X7m60z2s9FK$x3XIH`UDP|$Dg?`z!<>V z8grt@aDlm5NubMmnkVISmG6`U4QiPs=z1o z^E=AFp92~eh=9&~sX(nAEG@54Q5^xBp01$-AfLacka8D8%51ExKv6q62|z@GPHJOf zltD}a2JucYk&%fB#5_rdVOCSqAHdqxGencjU%AYD+FKct=cGtvRz%^p@B-V%c!Dgd zO6uwEp+b0Opi)2mQA!K*7@% zv%q+$U~A>)f9vRwe?L@LUk_rkoUCkby0AJYMey!`JeZPTXMF4I%w~eK8c{2zMH8IL!U>nlxp?W4s+t-`EzW(-@%Zd))OEJ~&AMUG#K^58`R6Xi?*BBv zDga;)|1DUx#}q+Q6;xQglu+x__$I$rhLg;`(!gbrC=7&6PJ-|!A;|gM=MW#~NUb?~ zpc6ber2kK;h5whvP>2^}MwwSo(78?ZLYqBFftg0+)fY2*MNL+6A2Y)+)V$tR!@a#d zj4LmfLrNH)l#~SWFP)f=lmBxo0jeuxWOv2ILAG~yx)`PrslUfw)oI1f#dTAWPf*Z% zb+|xZU;pRNpW)%*&_+S-1Mz`}n;S2rt4j%%1_=oXl3!3z=eOqo4iNNAM5=Vmb++LU zOf2x3K_i2L#CMtVF=~~v&_Eo>FDvV}k&(T^(k76@b%6+?(P#~g@9?boxj7AO?aD_x zizbbJ8?%j5LqiqS)%#%PGRZv?ahsEB4mxG$;J}h%j3=hOE5glPQd#K*)$f%)_%M#g zjQ0zN3PyTeO1P5+1yo(9s<}8holgG!MR>tbZ*6VOz86Uq+kWMi8UgMT%zDfattu-k z%g<*8wULhxV*nQF6nFLaFHM??W^hYLWICqk@9S$B%3Fjn?7m}E9Mo~)aJDslL|yj4 z$m7>wl941`X>wUoVj@4#TJVTV_?sFUrl+P1naLjc`K@}1CR?{zl*urb)zplwc)>}B zQUNZ{kAR&^zVQSIln7UE`Uvxv>{r82jR?e33DAWj$?sAQ)`q5rT4s(m3_5f?D_8`@ zcapBLoO)aDDDyN5KzRXyv!|!$^XJdt$UspD0I~1acj#Is)H%W>2th~wyiSMbnoB)d$^Q!ur~JGT3=F95-Lq% zCpkPwNq?@brKhG|Cnq1yktcclhQ`waj9vkkl4#|W(tn3XdiqU!8Lju^$vWw8m;A0PrWqa>H`lK&D~nIJyw6DyBlC7~(Z;0_x?TsZIzU#o$;7)EK8JFs zu-tYMO0zr14Enn(?s9GSSZtN13%I8*LWWS)>VW`rP(b1rP zKO4`9Sa6!6?4hJr9@sF*hV;4omC>0iEB_+Ru8Q-xYeM`Pn(D>tM#FAgX4md_2+H@Vdh-sO{ zg@Xhk^fK+j(o*xk`Iyqi{YmPYhV3qnm}@+Xndar0ngu5vl4n>5cz7U}5fT%>7O>Ye zGcyDA8fgIGh9$8;i`~iDDMKRA;QbJm6lmKzg^dsYS{NBke|S&`61as0BjR>-@}qeu zq*@?+Q4ZoYoJiOTOG`@_^N_O}k}RotEBv;7uX=3|cpE%^2b-s(hOQVrtiQkCNoZ<( z91M1uoB>F6fGz`V^-1la0B3hOOJWp89J@+KcMdBwGCchBPYX%L_7(cwFrteW59dNj z06A4JB!laqE^*j8I$CXftf;28C;`x^@$6^~fu~EpOiqRt?Bfz%91D4aR&gvI6oFpy z-oXKsJa8=mQvC?Kcm*teMzSc#$Ycf`t@MLB)s@5n{$(9H`xU(NwbJIz_4T_JtxuuY zpl}%|Vz|@?#qQi;BPRs?8v%d{AK2H|cMJ4yRaMo(9Md|FoAmT=lvB7tB^%096@#h^ z*{uicNngyC8?vyoZ;gdu;nWtgA`cf+Z6TLKUvqRxnn+lMnM4)sE6z0Wza*e5r+Tc zCPioD0#t)s0|YL>>NEA9(!W8M)!d++4`!C5oEkVJVvc&2jt6|89Kt04P+Y24_Jp|; z@&X65{?8qK+^>Dc82E=j&d+?Os-4fy&WKy=1k%7LDFu_cvU0bBF8H;ObD9FPsQ0S2 zrlw{~2~ZNlr-Ew$ z@D%JL_zF2Ei0^<7k5S;m^1-)JQUV*;7&gOX6$b*#2(}Km#Rn@xk@QaieiRjPLZ~8v z8K0QaOn?g7BSugLy^fqPFfb4-K+1@VgoN18NV_0q1aGdYs1y9_o$YOq##LA_+#@zt z_5BG7wA%OD>>%^jzQ-1Fp48UT+JOui&Sx;*)L)CZr>wYczP#%U@(WNIDu*iAF*BeB z&r`H~{R%6U?8+6UB&DtL2lWZ~7`@V)rOrqKS|WsP+-4*+3o1wGi4HD9G;;WGdS>eO zHxGYt_AO91UnKoUh|<&0$-D2D8HO%z_@S!2o$g~tw_|}GP4Z&EitOUlK>w_0^%EQ% z9+OYg1rHz+QD65uk(!*+4r9MwC0`CcRGK|pR~}ykJ>L~rct6pE(gQ8CkgPq!s-WqY zT6)?6jUN@5$sH;U0=7uNg8fKOb}=**`}8zms>*l@6acW+{?7*Me+3Qx&!bWi2W8fr;hh#3`Ipq`_*XKk6jJPuxF;KEB$1TtDxYGuOetN=r`%QxFFSCs%U? znxgW3a>WDzRbc@E5fKprOvDRapuPI@=YU5);33E}^!KahYL;=X&CPu-E`>*N@dJo{ zgugXzra|W$LlSFdditH_O5^6!MtKnT4g=U-2e^LUx1 z;f6UFn#Dd-!^4$sKheq;V5hViho$u{<|?OnaEObG1JKn=4g?c}>t+XBs&Naxsw56w zkH}ZTu9=-yqQPf(5y(gc0B|yT49T-(v9~&ffe?17OWFZ{{_dl4DS9*gQ7gVG;syA- z)xD$o6M8_Nzqs3M`ZJvcsJ7g7X`+gMbv*Zs0aXGl66_GDkr2BwylyCoNx801Ru$&w z+ksb#yoitgQjTdGSY~i=Xu&vm#0aJmfb|dH6v0G|fpx;8_FK&J=cR>(mXvr11ViEa zmzS5l`{ba`39DG&@VSi5WelIK0)-4G3i9HX;yaZ*{8RVNwu3WTnlobE!! zy?>w7zi_5C>F1{s=px
%e!h==u2*kE4}qpvom%+ox4)C&UzEPIcR8sVNBpiDL{D zfww3?b*Z}?qLrX`7&*D8)(4*lC0pMeFP9V9(=xmZuA} z;!726Xc-MO9DNF!f`S6rFaU!x;M4$d05M_+V6ct(elhJ?=y5LUc<7s5yw+_C4BWvpIQSf*2}}?msppBddm+q1GHF#-k3pmoHyV7j``ZYYR-SMyH8# z=vNuy{yxBdHcrG*XlGa}t6*DKRP3w_y+1uY1xNBJoXa(@PvFEE(4zt%2thOM43-6V zO3kO!j$$u&_d5_%2U1&D$bqA}HZCxg55wvy5%WEpz2OSki@}Z74p~0 z%F5^eevyP+=br^Un3jz`-hQBxQF4IW_K-L_h!*C2FWO;cRup=GBi`JF^ zN6fQ>T#ZTfKiXH~PC3~6kyRf-16F9BYWg#|U*M2tL9ZvK7JA0P4+sogeW){RT6yns za{-Mx1A*K)y9ikc3ipGOw{Zh2Uf7T4T*an_9riY^r6b^QJ2^V~AADs``Y^JvWK0Lu{f}$m zsMV*p7^@-h1iex{^RB+BDIE&M3sH7Z)FO@9!Zh&3Nz1GuAmtiNj{~Npx(-Il!D*ml zK>-K`5jdKtOlTV^!j$0K&30 zUUrWqG56iOo^;{TckisuP7XmzKs0o9g`sA~%C$KA!x{l$Lla_IUZNmT8srSTQekN{ z=~o@@x@t)&m#Yau7ncDmzFY{Fzkd(KTvSl7YNsC(>4rL0)qsyN)h#XU2=H%fdpk&# z01_Yr)%)u9uL;|%4Oj%A^A^U}iNH#MeNYG?txhY5K0^;5D>a2&(GAc|0QV;;0jvNt z4chJ*Ku1Vn9ZWiw_NMYjQs0Ajp$ajl&mmaQtvX2MNzoJ=%TpGV?M_0TzcR>hL8Zp< zdKj6})m_@!ax>+YXF@ONiFH#RWVi9InrX(ldh5R>FtX!0g#f0HXZ25Xj3LDl)lq42@P-*uDvG1Uy3;WU0)o~i#v6%fZTL&&9tUo-jE(?d)`(%;pEu??&P zD5KxP1dLwt1~9~)RQ~s>>5EHCDL~60^slAm`FCCx)YP;*K$B1lA*r{tx!LFZ?b2z_ z-kRaRqld>w4-JF;z*WB9UA}NNDKhRxI`96(zVql{)>1``vTNm?Tnht(Nf1nHQ}JPk zH(RttC;xzS2Z%mdQ@4b>1(sk-As?Z_rsY{$SYW_UIj1oDU0?q)5X_sJ&LG{29RlSD zYBwx*d58~OC8x;fZ?XcM{>zz0O=N#MpVS%ehUyRy>3S>vR2AYDJr>2^l*D+ka+c6> zx63ah!w!P7_aQ?AXysdPFC^9`V6SJTrqTtT#7oitwlU7NuP`XCj-Pyb_DLOFr8$MK ze{JQ(EvH05*ROX^y4$M%KN6Z ze^jm6b3QeeK;fFFjvV$`fBpjY@**f@4|oYsMFL2DIrKc{)dwRzk3 z2|Z=R-`!nKZf+LE!LhNL>gwe5lS5iSwt|9!GZHwC-riNP1%CAT#-jXKp%MdXvh8_g zuAi8YF#r4a%U7?)->YZ!i7MgqkYN;e@HT2@{WwHLA!Gzq1tx9aIK}!Ect}EGVs}^9 zX08E{?8wVB4po>O=*&zbFk_SbXQ)rX5Vf!E5e@j~1_C3DfWeLjRu5?`2%U+pu|g>U zVD1Xm1SqaZZNQ-iM3Q0VKuM_!v!Jkpm_eJ*=5QaV~&L6WnzvFXbcuYdE#404mV)LtVYfR!PI3JA;wnkwQ3p=|g~ zN^)|EX>-#^(i|Kd>WJy1g(M&GACFN-h9F2lO!7V#F_nM9kt`roDzkkcz*M`(+U|{1-Y5O#QTeLz> zdN6QUSO{Zv%CImo+GVZn3i3Ep#Pub`J`OKid8cUJL*Y zn5ZcNfdY;L7X>UmD1$qH{ycs5>@J){P(2!f1WEgw)X5O&V}@8j>&9TQ+BV=PzGlCZ zbPg^7L_X>h+u}575V?6Xw-=kn#u>K(PBG|x7=r_X;aGE^9sL>wWb%AvWu^FSnExI= z0$_Kf%5f|qfvi9qLdXV?H%|hYxeodg&>{)o%E%j52k^2+Mn;g&#?+JJUEte!5^3mHePJ~{?m1XsENW+Q++fhr0mp9&f?z7Guj40*Ui zC>7MFb65ZdvrpO@V|(L{v8z z*TQvy$eBKc!vXJn-bTLPir9{2lsircY#$8h7f3VOB74J z>0|oE#iLA8H$>A!Y<3kmgl#Azn5SGyfe#fhsty?~L4dFRohF-a>Z5DP+XM`=5Wa5Z z??18T#0M3TM$AV>R(8g^(C22qc-6$_OZ`7%F?z~)aAEX310$ik2aCf17iTvnAzZA( zPke=(T)^X9)g!B1d_P?mSSG%TqG8`HeFNCa zwvl`5E~hNTbix-xj*cQcQ}sDp#ygz_4V-Po$D>i`VgO_b@7_idCv-X${rC_IBgD{z zXBB1stGT}a11O~pavBhD zFNu!4tGs>5Xq6u9uP}Q0l6|P!Eg({5VqDLiZ|-lMS8DknE|~as{Wzja290L-BeUA7H zIi#;D&|=WUH^52sW`LiQNB~EbYEq?QT37DM)Vo})5F06PM{#*uGZg_MM3kvTa9+SK zPrm~RNwwLN(<#oqj)zDU=hk79D-whY2jf>I4^UAeW_sV3ujiyLR4?QUc*(i!#0!kn z-SC#pC}V0CMlWXqzSa~WuzY)?qSUd^WJK-8ybQ{%R%c9DYS-!Az4Oi8DMw}H55$^Z z&f9B9$;|X?jo!@oe!6dIm9^riAN*&{2J~O{}MCU4CD~q>-k2@L}>& zc17$h9($&&gr>Wi<{US#Q(S(}W^KXu>G^50B>P#v^Q?&AE{3Y}g+DfOFx{#3w}0Y{mBswdMv{rz*KNF^-FP)p zV<)T-=BSwPPp14MWTXRax|1vC*J0sQ{!)rG53zo1qX!Xq#?qb_L5fwk%G&KL8jNIY z(!!H*JZFv?cI=94D$M$rbr~*RD_Lv0qakVgcCD4QUR_w{A$#i;|4x?6KCk}1w$RA$ zHO#qDf27>)7#qCk`21aAB6ogFca>j;z5;2s#JumV`XyZ|{*wH{)n^y3wqNBJAM#A= zHyH{1TkWeQ)6$MpCF~F@&%Gw` zXu_DDx#L{#bh0A)XQOkHL*j=CBcqbS2dQ(ja#jC53-JD2^Y|=g^ZtyrYrq6;(XO4*T9RAm^xTYvyyKSlmsgDtJN5ck&ti zo3~M&Y(gTU1j0>Qg&(g`QO(*% zA%RbFT`oWQN*E}e8q~7sasA8kGvU%J_X3zXrPtMUF0XKFIrfN()A+tiw*E3u9MiSu>Yak?A9XQghV6q57Y^T`zS<-9uuBG$ z9-^gjKJkidRqfN;V-F8v){>Oc7~Sv6hE0ZEhzdv;Ea7wmL~3aLyGAGVFC#9~yekM`P8wbot|5j(E?wHK?iRd8!N0VwtDq_9glz2krFfg%njGPPz{+V=JTm!*4oPh2k{%?(seY=Gz~rTGr_}Ld=^3hwi!D>#}+&BZK{UZq`RjNIN{Ts~%tJmUq%V zVdedCqD7fyYt0XZ&9R|IeHn7?i;X^4rLYhf!%P9wfp zQSZ@$T4!>vzH+i8I2OE70Ac7drW;PqxU`qCkrz+DQRP1B3@*I8(aB)$qg18Hm>V!7 zQ%VxweCC{*{>i}*Hp&{irExIN)v42RDo(pa6`*P}*&x4@yLi1f@;oI5h zY_F?0>rNa_l2&tR;ndaD_3uZnNo(U|8n|w!MxU!calXB-<7(E*s3di2eJ>}~`ugY~ z!7%mpXy!2NCsSd}E^#`gBm88f(g|UyO^?Xa>K+lKnXB69u+m>d=tHbsOKkPOQ(upN zQQOFpEGT!SySzwGrW_&e)vTG6@xx(4p|GRP^vzs0G+eD%qfwf|>Cw|Q$rW-0-ko``-pfz6YA~fygp+I$rTWy^|^I z#U5}ahzbb1&K8ZWK^UVK#I!eW-l$Cnz}RV;6G&c=NuVcy-<+P- zrN=-ngrDCNI81D8ELgRg_O}|;6JNaGwfjz=s^YU%k0z*kJGKTx2N+zfH~L!4_cCU@ zgNcbrY>kyhzkK-ugCUD)YH)iJ z@b}fl_q2f^!k8M61h6DvgxA=}sM5e$N(z~d{t3VN5p-PqKskY)05Y)2)}T^yrgTFF z#lq4O6am_+S7nF1fFCL*fLA)`Pz7Std)5yP4FLe-$)a!(z}}DzqyzduNlAdMQGrzKe| z$)o*K{L(3FUmG8vrHUV^8C$-3MUpRvA@hhJ5Yzxp1<)TL{dA*$&GYB|5DJ4^dK}Gt z!C~JM_k4H8N@wK^+P3rf$LxXv7#@MFlA*jjzCGOK;TjSmcbu}X2(Ff1&(aP<800vn zYh3A4x-%q$gWTP3YYHGza2pnoBd`!SQxPbTiGYAmcl-r^9RxUeIXU=E_f8vR5P1p1 zJ%A_y(|Rn0ASeP{NNxayg%}pFF^~=*W#HxE(bd<-&B_YGG8OXX<44S_DOnU|P?L!X z2l>->&Gh%ul4S7NjThyz=a)13JRj4Jz)KV=MyE9@bcyG>?|!gIlO|crDYCi z|B0QDJy~B|j4;=elWT`Tx_DW+Yz+t{4h@L`YnW~H-`Uv#4dp61Ie6So&d!F0hANm`cUwlF-=EF%PlA~F`NzvWbb^+nF>3IeBj{w>PaR{PtGzUi^+}7qYL^(kmS2%-J z^qfX)A2bX9hYx}1LF^3n?_ggaW-OdCf=ln?oO;g{28Ki1D_5?xSd_t?3*v;}g}_ui zvttb8EC8+uu{Mxc;I0nT8;5$TTlYGu^VBl@K(_^Y3fY<*aah`wkR#aM{>s|}Bd_46 zeleR9;ZE%QTpB@!n0Z-b@fKGAioI`ypSZ7qW{*{C&b5xjX-8`(5^Br z?lMFM2EQ>f9MXWD0u!tV@~x|Ddw>7z3fmjGTOBa|33USA>&vY}J?B67td$|TDKEbS zDZFp1Vt%`fdXnpWGSbpmG@LLeT;8^{12v*A zOLA*%4H7Kc@X^_m3_yBa@BlH2Rx-%R!lK-&GXZj_1m@DxkKtp&JZ{7#YB&k^Cm`ND zs|Bytfvch!n_F93OH1-O1FaT+!OVne z4F(lGSrmvi3zO?xTXfXaj$mUJ6$OB50=jW#CN1bu>`9^?zrTTOLJ*D?sC2cn`v=C- zGdf9HObtxS$e^L7E`0L_QzeaOg2O%EXs{~WggDuN)G;hu2rplyrOkw`c*Z~;nz^ft24CI@TkPtJI<~A4d>C@54;f`e28~Tf# zl4Q2FOU)ifpv6M&asZP8s;R1i&=i$32wNRVh49axrM^4PtuxukqtV-ufByb`jGC==5AFCxiVODC z6Nj>^tSjg4nYn~qkuI%C4Ze&2fqDV+DpCt7UX^?9T3aauk~@Grpi zJNxp2lpekxU4Q2k|Kde}+F2{Xu2}g6tk2)(@|c;KgRhXWo#ZTP{b;H3dKH4zH}mtG zMBEPQB19Bj8QXsrkY9fD7)wrxE=cD`l0N}~Tv72ihxh>d7G7a1FE52}`nYh6ENntN zgNLnkgq))?1hNQWSyN9Q@+sZf*Wc`TtFdLS2u5e z;0WFQQNKLq`j>Wskpa$^Ys5){0vbKoPj3w$>0Op|7sbVkijtu(xyZ;8g7DZ3;k8bQ zHLL&gqONCmZ?Z~~hnS0Dt$?f8X@}PS0*zyr3N-&y0-hhIbRBkDfhy;_IT5G&;_r^!|N8AkKB( zY={Ek&6~wT_W-PZ25xn6@zgySA0LnT6EZsm1;WGrSx4jI`o~#x6;NFG`1*>O3n|wB zmg({JM{KjrRc+_bN8yUC&eu<$mTSp#)9XUvx)>y%f^4~RWrP*FSg2(rfqVBP1?y$Y z;F!XB6@?KD&z?Tb;DCQC)ouEJ65~laI_q^6nFg(0Tk5RcN#L){*O6IBF*Ol=x0~>) zZL=w6x-3CN&VN7zp?pzOQ{%mD-@YB&0Svf6T)C-1?M)fzn;jjMEmF^)-*;zuqo4j6 zdv{mY?6kD7lh<={#Nq~ZSt2K+9`JGG>gE}?%tm-#(Lp zkgd{by`S9P-p_2eLRSHgxeWc!Fl^Fn55Yt^azsMV!dyQhDu=cf%V7OpGf}yfVQc@2FQTTUi z3PDhcz65aHYmkni2<8%X%ID9YI|*V_i16=(Z1qN7uc@i&tz|2r?$jY;5kmR+6${1a zEczdj(}oQb1t2KAE|_h0Gk)>rjg;Wx>S{FHNG9;}v(X|ah-Dl#H@8tj%H_-Q#+O$V zMp5?vY-vFeQOh@!Ha=r7C1mB~Oz$972;tI#C*w=7!56gRWNX~BQwdo1Kf zg=-BT9#Og~q-M%2Z{OaJFRHhdNYSxJ=4f`I>?mJ|jK z=9ZSjM4bz?+Q$i@@LMTLnM63sf2gYJB=|8n=9r*C?<8=p8xn&VEAktVWdaV_Kt2*i zg@Rz1UIu~d@U)O9ClKdC>LXKW`SGKxfDwUBw%rr|FH4{?L6n@OHU)YMV~{en z@R&QxGmaj`1*T7d7FCQOX!G1192}yDOoJDp9qB}@TeWHxgHjn_FmP?bkr~I%%{kNMSlZU+0Nl4K^2~I&$8Ps}+K|rLFRJRoju^~;@`1(JL zxwd!j-1+n6ew9{v3Ib+!3Dz~f@b zP2s{ohhM3wDi%XB%DZRVNm8^sPW#(3m%pa;;e&{?m}hORi#G(*f;QIVZ`E6?4H0^#sWDQo3)ipIojvwcRp2G@+qbFwbY>?qE zR(N>iK`?++RaakMnmnxoB8v1aBC*~f?w$d~H@tYIXy2xqnF49yF3wXFpVUMV##Mw6 z@&mpWa-$ubJv^clT|9zFI9VUFQd7O?5U4_^h)`)!6QexwqQ_(yNi{h<&Z4%totVq9 zgW1MN<(K$FU$J7GRQaPvZg&as=TccG(<22=p>kC4AOE8 z<#7q`xZtDaA6FK2F7#5JXP#&?P+saT;TK;r%EN767D`6gE?NZ3@*R!YFz-{*dng&2 z#>lT(5~MzARLO$}IKsSz*}^msvD-q;`S^iC?dAc&NVBuEyBjrlXLTE#xKE!yGhxzN zc~2IG8F$Kn0JC-kto5o*uo8|9q@ zrgDTAu_4L%1N;C6Z_Cf9%{lDu>>OU*M)-#uxq0)=^RDHbY*m?JN@H>iWiWFlG? zRU<}>z;nc5;X=qKU4;CP&108dd-3eqDh?>*BE3Uc;T%RPNf!?}1UsrDnVDl5QIe?& zPNPqxp5n9NSIRGl`&W1DxQ7l4?h@9g2+pu1qfVzXAt2oP*bIdOYRKREbvl>Utxv z4|4sD47**cTMdeGL+Dup1dRQ&g{S9xfz`R#%>09kRU|@}iRJ zMr%#U*lE4OrDs&Pc^GQSrAxc3UWclp-m<78?Uegey6~=jy9$_|Q4sw082dOmys$b| ze(QQAj99pcWTEpHFA~T9wpOM)o)~hAucA5XBrx9lUR8ipl7t2+ZAN3?bSkCg%i*!@ z9QvKSla^1B(?YR8+8pxx{XZ?~TWCif<|yb!opQMF$zx*ymYfP#y%@%!w&Z^^ZGDTH;32rHC8gpAC5QeEKBGFyj4mr5hEE=Qf$` zYCj>X99}2i-sg5h(&V5tpW;bF^!;KE92i+Cjv?rdOu$3BcmG^pZ$!71s-Msr`Rh}! zgpwKEN2#l0@iP@a+tO9tyLLUX^tysER6A6!*8|&|_zY2CKvCIdGIXCixhMJ3)j#Vu zEnAQ**`9CtZSI$f)jzVQOBHala*iMWGv)d3idi9fszZX^oV5prOGktUHH^D%xwH4T zAp3@nw9xp5(%QE}2g=L8$qN}G#3Ux}NL%H*VMAO(LW#4sPjv9dHyv$uf!jKb$vc;~ z>_O$sg~6kS%gm%|OP;tqIB(p~?3_xCB*pC}O^bScy7V)Bad76z;G88LR$(FgZk)a2 zs6|Xdjo~R){$|eM&A!bGsf1gmsVUXf0IuXZH@AMV&Fyu^Dej~SfFOeF9;pcV zt(75oA5KQ;0h5&mRjL5uABE)GfVCJjl3drgxnVP;<58hcfsc95fmuIcWt_A>vmBg? zcK8$)u2E6lglm3{@8QID7IS=On3_`UP!8)PO@%!mZrKo{1Oo{XgkThGdEd3OfbMs? zv2iyk`FHr&ey^G?>rWn~mCe0;nNhkI`9KJTtd(+l8AL00om>j;g429N%#M3cj-mbq z82leL*stGS{P?W1?Z5y@?8Gytg2*#3-Mej3>0@D(URIL_BD3A;1-7|ufx(7`hQvfa zKV3l?AVgnZUsqTDdIyv6i|cDGAjRq=y?y#L!pg(d6;jP{%Av+3kRwS}ryXyG0eQmD zX!5(3pP!$VmBnpNk>H=o$;gDI{Z3!~p0-`TS&D1Q+PPpJZi(cW}B+@($e(fcQb4zzsQC=5x*Oqj5*q_$i4?mO?iCjpC@j)pci zWypR$JWLL$y46+4hYp^}>|RbTGHja)VyaaS4h~dMxC$qW=R|FkR+5JUDC%@h#&3En zAF)zv15jOYvFn;Ox3t{J8aVu^EB<7(j#Fk}P!JCQWU6fL^WU9-sHPwOL^+gh7L9$@ z*kw0x2b)oUd`Yfcfz}y&r@n>cbLU}r>lI8;Qqw@6J5g8%r@t8&6~u^1!u5e1-j%uKc3cJ{LO z$KKe`wPs;aRiNc;y+Zz3dHLMMi}e9v-o1N=*y+NhOVla~f;DB@fC1gb0`fMPFo=xF zy02fp%+Aal(7*pgl9p_2KOw)mQA-N+u-a31Yo8%YudJ^Xhr~O$Ub{nL1rOAX$_x4( zZ=4DpAysWM_vXn-=pHsu?ZApst$9X!l9L#uNOINwFiXZe3jEZ0N$hYg|FdjzQt zW;T6O$bdrpnqabN9SRk8@tsmQktsx#`*J4 zDB0-aj8h-IejR|u5Nr{WEBenNdOmf)J$zKHb!Efa74-S=DFtaHDYXIj-@bhdb-DVQ zf3qO)8^XX^+L#Dqr!)3~fE)s-i4*wZ#bs#U^wjt7&!UTk1A;(m0J!3mDX@*P^gG`K zT+mM>v3^kjF0R)f*86Po@o8b`fYTHOmW`~IDth>x78!^2HfT(m^!(19j7zSOb$^d; z0%gWRmm~|&FXn%N+*X*~>w|a&LYbI44YWX-<~BhcUBEH2w$=jTBgFUK+T2nf-v25i?vMe}7&d zti#pr?z^1ilS)mU)CzYEgt9aR4c~_kALz0FAx`AT{{8vd1Lt{f_{F%NqZg$jS@rnH zOmtz;z4%D^iL*ho;VwbE#g2=PY14um01k8`US4PDb;*wu#zUrVny|jc{K+Of(?9`} zrv^NG^9JY=x8fR#B!n`6bWoW2UJAP}LQUdG&`t2S$gX*YfXd_u!tH2C@RKIBlKMHF zCZDxZQ*xW$FrK251wEnuo3!>7C3p+L2xx{AKkJ}w*{fFyPe&ZG>MqdkYsv58Gw3EE zF)ML`DDU&J96Q$6+Z%H5D}W2@ET1#>fn0Z90j2+wAx+=E$K(xb5Bvw& z{x>Dd3V9vwD!8+qF=sX{jYP9Wsg1D)j`ezng8Ae|=$&5uf4Kmmp#lf@%F``QxVIMY zwMmJIQA4H$qi+D@0-uQ*G|9@$Yy>94M7YkI9l;v-MdQHhp(-kW`gF!<^Xb#K7^iZT zDHlM~tv-ILP&N1w*SD*g()k$Q-= zGvzK5_Lzq|oZ_Rx&`$DMl#_gf@gB2tRgHd%y$kQ(-(`_O0-ZZ^=19WyxN&E0+)&!k zfA_9kkLDssReyMgJ85Hc=HyAM*|Wbg#bX*rLF}$GMOm3~=aD1fO>y0)94&K!(dn-9 zn=|09vj|YNyxbe}Q?y{&r%piJg*ZQn+;JR|oOTiX)bp)YqAb}3Glx&FE= zSR+nS1v^gljqe&3WMKMBaDq_0p{34-wjxg30GQzBd|j?|K7c0p#2608gopP*7o1NlXC<9MT3ij2hWgdaiG-Q_|c2F zQDVyYI;{XzJwX_gFsfmD${KfR{m79j zC1J*?DZUdEO8D^p!aWDsdRXO20Y?+NFD@>?avJKmG8Z9~!;G`K!tQ5fWt*%0X3i`| zm9%*N{K#(ut}gAiW%GQS$4{P^^y(v6+t?VTTCxQq!lXI-MpJkl!%rR*J|BKIlRexq^t$;5d0!A z@EYx9Z0ut?Le3BuTNK7f=X!vxzS5i}O9oFJZ9b~@4m&y|!iW+eBGlSJ6e0w*!##JK z!rj*yH%`n)4bifxiiCBz%7oFZWPBvETUl5#FSj# zmt@4_DMPQ!Qk#CX+tsC*LF-9?t?-m8DG2H>Ufk&BmSe4cJS$6jyGdOI zCWSLhOc>fRO~K}pj`)7!#GKUBP6ekc(gjAo*lb`QS_Gpr+ctZj@-{RIG*AbXw^5~~ z{3W=5{qp60QIU8|?izsL3ZAi~L>LMJ9zw~m4vvG; z6f-lkL4%S|G!qk;Opy#fvBhKC?%i~x4K&A8^-y>LN*GaEhYj76MGyPZH|rs8(#c6l zCAI5`t>}`q9rFb(>+K6EUhCmuHqBF!eFgY&V`6dk!iAv5_g&jI4D1`G z98+{}l9%G8GiP98GvcWxlJn8jRAcTj@p{rFe7lW}!=j`6>@5y#soTRxPQbR)n9&1a z{efAlRw|J&SSVcg-%ka+D^>nnb*kP?c%hIx^loV<~ zWcBD!cbx9#dxCJ~oTpWd(Kbc^rSBg}BcNMJ^S%PRs z9@zl_o?|OkmxcTbJTMFQqpNJC7?ThHEX&y;j?t>Te7VrZM&(o|LA*T!vrKXsZs%cF zd()-!5ddX~N%U=j|0n{2(VTK~35(NyP25X(c;?&M1<88YFz_d zC-!wUk_M3Q(F=m*3)qF2>;NNn;sBa=Dj5P9gAGzbhOO)R_1X>xp7M_wWRfP?Jp%4W zC_`Y6A)VI)G4fr{!QSUOb$Sd0>C z7(ggca|qn*(|LnT;o7xtFeI$49i5!2-oJkwuvTSEf}4j2nkm&i=IBct+-)kXNbJZQ zC;~Qqd$axcWS<=zF<1#iLb0rsvRNq_%?y?@!Ng4djxg}z%LsmsaMM_7|Y4M-Gd^u1?MjJSf3qTKX@zN!D11z&ydL!`4e{itu z_1$G%oL8&>DWbU2 zrmySu^2Lk&B!Y^HZ{!OX7eDODbdtU@7Dr(SD#=!g)y~eR?T<4wJh%W3F7(yB8uqp9b;T)nEd=_GA>x890`7V)~Ku&~ZlIUrozQ#8Fa33|YsTj%3g zPxZI+4u0k0?wTzbj71I{SjgdI2xMXsPWuPc076T8ru}fC?r?P_-Z(5?48RUfb}S_& z^wkfTL}*F=(NYuXXl1b^qr|^Ri%59DHw*d?ouPOHMoOlejO;VyU*XjqAznAn&M|Rw z`yAG_r-ARtfAU4Tzq)(!dVvpuzS#x_U(L&IQ2SrWGg%*B0-uv&b4PClyAsq!BmjHvc*Nmr< zTsG$u&>fQOhTdQQMt@(XA)bM&=H2Ot@$SQje%hAGy`$Tolm_Fu-8AQ*ARL(WEZ|%p zMPRT z*4#}A0hidvqp*pGyonGYy3L7WH6DXYBZ-$ejeGyOG6T13mzd~(slt^NA; zRY^YbPi66@CpP76mG93TihtVo-|TrUsfl#vmfipTIHOnjbY=FOfBCQ9=j~;V$-ug_ zbkC@P;nvF1P_6QpeH_);{-u}2Cg@Fdk{+*=6p*jRA?xAb`&r{LtYG}1z zgLF@B^VB_y$M|vu)|=@-!jh3m6XSjUt@V9H5>1D)gt~&suW(9ey4Cui>k>7MwGFA) z#ksRytL|21=}BM7pnFuY=Bh(%SLh0cyOq4$bEHjf@Xy6pbMY`c5ajgNMX2nza3bWh(`+r1lpo`?B1JO}6?N;2I0_Vw%F0Ry1v#1zg;{;fVI&v*2X z;2?#81HUqEgqg#Q;$uxlyGa4bAz}>lC|D;csQJTeAFZfr;?VU|J*Pfyc%~6mUFQ@( z=Zxua_70g=eh66BH$0}W`uc)iIZ3x=8uR>L_n4ubTDQP#TIVJS+a6Ye;H<5Fv&Uv@ zQ)$Dn0;N62&EC9j=wt3_tQy}Vz*!KSqGzf8dzv0E1a7v|R0*Cc)f6FKyw|DY`R%?J zF5G&dm42x2(M@`@rq2HL=%&Wv9@h?y_|nkjRnW>&0OmqWd^|QBF0|}B1L42s(i2zy$?hiAI;{Ty7@f=8Q0|oIF=TS8K zT|<_wi0bL-v3a^tmDGVWy8&U)I+2U$~EHrHFWD_Rh zl=fVl1BVay4n7pA(6=u|dry^JOxbvm032oM8~I5!}?cfSn!+&D&-{1QuD?R0NNTwWzlZ^@#EPxzzPEKf(YQ=s`Xds zs>Oa68XC*B54db-B>#sg1Ge7)f1&~-)x#5EUl-58#MJcO>Lqj4_fbDGHln4^h7!T3 z+GWkqy~QR|Khy4z?^)Hh*)YF$JsZp}0_K8Y2>AdJyePnLn{cVye5e$Ql&|95ZU7$# zW}b3hM!ik<1Q4D|J{fy@i1nH$mgffMa9 zNAbAY)s@`%i%F>*3K7BS}yCAr|{~R^^!k8&`&!4DO`D`25w|2^diQZk$HjJH=Z2sZ6{^H7t zAE!*z?NV;1Xw}Pi?Q+2XZCPg0!`u5kH*E_(>Cn)o%BBKcJ-wlW3UA$7&c{q81;c5O?Vhq;FMN+4J=I7{miX zToCvWC{jtXkoM{*KOa5PqD)4*h5<*)$POIfLbi!P33evly?gY^m4U|F4jY4F(Gt@N z(W)aa1v5|^J{-Ik+nhW+3AsxkJdjZ)2a5AN%?eqE*9@c57(P$g&z~FpN0d=d7Zj`@ ziEw~@^rjjZD3zA88_0O59wkHRaa)r0B5C z_F+uH9ER{g?Jb7*WVSo)KYCP6aKvt)Xhu8rF>smbNbsR1W<6popK_0NkfcXcuTVe5 zQ&Us$hr&iU74XHJHxrkB@U^|Aq)g7O?j&%`3cUVM7IGL#wN#`;=A!0$t^@@=5iFS4 z1JK3-HR=*AsZnFbP$f{F#lgOz^l1?F0?&m(R1!Zr3BwR_GsTIAhX(_$?PG4=x;0cw zOKZXeJF2Ohj+Aqj3m5wHrQ$Im3$td#aOupnV} zcF-6(lA-hT88f>2FU~%8%vOCLgk(lD4AbfFe+FB3gh=*@w8?(L^I?A+mzNxrxPO1q z-{D$cDsRLk#KpC@wWYLYu$&7>8@P=Sa;qr?%hXUX{?&s z(PPJ0;Zb?nD_tzbE=l;G&i5~3aAI%!q7#;CTSkq!kU67Q%sZ2o`uNXD3!VD;$kvC3 zZJ0W$=FC3zq{5u;zg-8|Ne<{WZ>esn<7t22jh%d7xs3MNZty{}B=b}I715lx>t6R; z7rX4PWk~6Ts&^I7b`7j8gTT;5n3I}@;aYiFBAe zD7TY>-zvjaY%Iv#m@qIidc-j8MmD4D$(zK6g*9tpCsosuS*XS`Q7rljZ*QoyD5hZi zRok|l+<{HQnVgLqKR&&@zA~g^YE*UBL}=7{hbmsZTI1x@*w(m0GD1dHR&`Fl$pVgS z`!tzombuX3QwZ|$e)+l{6z#x`&sR3WWrGUr|4{T)ev@~~T#?v^562(4PiTQ?VyPMnPlAeO=~4&({DF2~SuI!Sy%xH+M>;lWr7F#A#ua|0O+l{rRCB&YkWkTf19%8`u9K%o`jusM+bA zteO+Sv&YYiVhLnH2n@0!$Wk{R0DcCLJA5g~o6cE!Z-p)a3qT{~u@rbd7B zBhrFlrVW~c&I>KE<#FX#-n?~oVBxV9}$F4 zlXXMCdnxs?9(nO@$F%B$7B|;h__@g>?%cRYVa+IM>6*O$dW%N>xccyKbgNV`)SPsc z{=K(uv74?bhY1xEiwL^I@Ni`_64RfjD;}_U>cWMIRN_oe2;0;-aK`u9X~4l8t~~}& zfr;_PjXFlsgc=5t;DaVhBB^zNB^lwk!E>UH1}ypawa;!jBN0G{IMe|W!ul6#&81E7 z)uG&=$K_x1;KWo3<6+P>H*EO939PXwUUhDRj}K4#@1H-o(Wkt>=L|zn zF)Y#vC4re$l(g}L@DZnWR0MD1hysLRzUGoL+u*(0G}1(FZtf)vdC~A|X<08?6i{)^ zpP_RKdAz%!}`V5jSwWQ-)S)RqJn8%mql(UI%!mEbjS3$HbYk^NGi{V`*@m| z+~zgH9HPvl#m%-`GiOe8TH1RGTc8gv>7%bP=Bm7I&{x(6Wm>Zob~KGW;K#S-eQBSK z^KggYV@5!R;U7pCB_Yeun%KKU?l_f6A ze*Keh*GP0*SJAoD;$EGs#k+z5hnDuXkkXs6Yq}6xE0~Ut-fv)0QaJODdD1^O4MBM4 zn0oMR>ErvM?D!B}{@Cr~+xeAwe$jK@L!!mG)Rm;?=oqS>$E3QYMM_fgDs?Yl$JVXg zwa>=*9tOGt2V?~iH8l_x*cdPi0cgn7O-~*@f{*bD{_MPYKfb&iUvubVZ7tTdk;`qa zT)xcHczex%JctuE==d2tDKDo{R#@-WSuhmwJ(zuocTz20e9il#6(0(*+HlV7*-#X< zw6(>fNMZ@aOw3$y%BIyKpADBOrj7^9x(vO7^=iu9eE=}1#Fs!{F2Md^fxqTpm{Ndr z7>(kE1lY1(_aJ0JT7!7Pg z{KgL-UdLw!xX6$-@HPk>yX|PX*MXPwb{S(k50a5!uoE*(<&YI#DtEmOT4jR6f*+^* zaDib=Lek_0@Ld5j?KL?T7b)C-{(Mke-1y_i+I8z56cri%Th0w3p7NZjB=G1N4EXQv z=;%q*fY(;jZqGX7!-v1s)~*!K2sAM{e)mqyT$0&N6u9NAO()Dy07l2i5IZUf2pS+( zInb@(mpf|3)QWez4w|%%tytnC`n#(jPrt@{2atu-2I8})81m{L;8QpX88>bisHkkd zvm8VaFggYg(B&BS=pvu7AWnKF}NWDk>+EUa(}m(S`(AkdrX9^FnYJ_(|PrBS6gy? z9rXMk8$HwPcBr#ZG`gz1PF`i0=Ruc-TT`w!KHFGaZR~V9dK2}w!B@?8%0>1K2Kf62 zUDf<&*~TCqn@LtA8u%I+pi2R@86PWneivdq+Rcv`>Z&F1$zu6 zljmO~uUOkB%}KZu(t~Ll<`{-6zIK>nH;x7aB~m2IA|@11>->1cp75G#&=U!Po&D5>rj~zEG5R~tTr8L8ulb*%fOD0a~ zp%~H;GET7k9B{)z#$44pF6xI;Q$Y{wiZ0W?x-MFsU7uAxvljVD`{irPE>=7#R_bs0 z@~w18rR^$d8NL-w@#$TS~YxjMN@ zQr0k2h-l4{k&xE8xIIHRWp3r)A%3rB|Iv?_A^oo8bC}ndslyJ%b%^(XP6-G&ppv<) zebaKjN5|ZjxZ!7H4+icSpItPpWvZb_UE6itp)X5Hhbm{)Z~k~R=YCJ?;evdM=)uvR zxm_QJls^62f3qQz9to%d0B`5t)rM#?^IX*+8bPj}-t(@9%Z}rMFequJiILr(UQiGL znL4KWi-kj?pc5x1ZP3>17W{F%a<^^0l};KmHJ#qZvunrZz)rWl(xNX!Fw`#lPi5(>CGRp(JY!~F#n_@s zH~3wCC&Dl|a(b4B5GuBln>%+ry`-SLZ0`>MAL1Pe*k!VZ3Gtio%V>QHjNF=Ul_I z*wXSd?WOr?cs%JD85gj*LA&tx%BvBa!+w7n?VF@Ki|$B#*E%}2&0*|1nH)P&N%;50 zF&6fr_x99$i@4M_-=ZV&;*73>%>@uB(9}bTboH;84{n60NMrE4`+*_sELtj zT^j@&BawJj@S{g;e9yMiNGh#3v!Z+Wci?i6{HA>@cx-}bk_R+j96MZJ2|R_-fgtS8b?Rg_8|)vSBn6@{LJ>wIp**1`~jQFveSq# zrh}E)c}KhkAo=04(~h0=;A=30UPBxH?s?b3+7wwD2*x<%2PR$RFJD4Hs3}`zeS1G8 z1En(Ue78a`lD>c20LBkbHuJb*EHXH+MNi{zIu=zoldS*rYFX>CWmeYh@rt)9l`3wj zui1gLD^$6rb8BR0Ii30a1<$75z8|Nj4s7mtxK}!V#(>lNvKnSf-I3Wd?D<%Q-D1`Z zw}atpnzFykG_jZqmGAFgztUE|WS~RA3zS8*r@y{uoWP%n^dri}hAGlOSMIE(P4~pL$(>;^lfavf>Sh)CUi65Ur-sT&J zOu8zxJ>EGcwp0D@-y4MjtuTurSv#k1>ox!GpXd#5@4fHX|K_@SgvZggjxkfELodl_ z`pF-Ckk&f0)5IpQFEPIo@jT;cBnd@X;%;(JQGWhA2EIV9_+~0^Z9etF5S{@3ffQ|* z&YhcX^qxlw{DI(IVesG}2sQqVP^?X(mHwi;&oZs{bGQ{hRU`h{2II!uo-Ylyt;SxO)9hZ6Jo+YQS&vTdEH|Gz;PQF# z?EdlE5ltRJR_6nS*e;?}KzJi*^JP?G*7~*v$(3*4o;Z0jYUp5Q_UBz9pco&rGPCT1 z`@CWyB`1?XqOk|tRqLqfnLH}?Ud^47*?#22iP>N>L5dQ>_x3**FX}*iTJeP34C z&8SJGX65;Y(Z;Dkcdexf6orj*26X8ew3bQa-{_kU#7}%^i0wt?a0#O#=f^*PMjg|8 z@VeW$1tzI;Q=&A^PnsW|OsHYL1YYa~)pb+*lxNHF(OEYRxOwcGBK zq0cN-CW=_rh!&3!=&#O{3aj91S*FZ zUWBKE2$}8VZnC2rPC>vtO*hQ+@8b-0*`t}6A>VAd^{2k4p8dhB94VXJ&WYY1W}Odw zuxx8d?YCDKKBuex7)U33zobNLD1G8Yf60jCTwQW|m;62};FQB&{7;kKS-8AdT&gc! zsuW$gbJs4-kt3bXy#%nM(zl#FyVg^eW}G4t#DB>}_krP1j{<-(XyPQ{gDw7{s93<3>=2;Vdd&tF`?y>2d@PoC|UC146&-RMXA&nAw>rlzawYbEbls_kXl(s*25%gdYb zlhBPK_yex$OqkFDipBYc+RjM@;R5buVLj&ikT%1l4UoqxD=V`j_v~4j;q&b6*>^4_ z?t*5KWzr@8d57i&w4TKE)lfO+Bjkc=C|#dF=OQa~ zDu|>%z`Ky2G8?Ugnwsg18Dc|7-V~RSom}Fj8dN%43$k-^wn*~^kxOhg^kiDdu!-wO zb3&yHHH6HbqkXKuB8D166r1wI%JBr-{N2NmTwq4vDui; z&>p$|fMtP@cf|8`Tz^^QMTHD$;!DN#K^x-XggCpqDUN3aVE$OHIwR4x2| zL|uZRt*uN#37f|8SWVRR0&iIhyMa!(ArM~o&5zB{ut-I$W(Aw3^#b-LC-EWI*3~ia zy~}(CwGy3Ke>pj7q{sC>>`Bfnw>f_hs`J$=SHy-KfKQN?Gp!CWu%5QQW}Bh1p+CnK zY`>GhZV{@VLu@n>bfZu>V8Lt;87w=SY65hGiOY5qZ8*HJmcYn4QXgw-1i|W%9@x$u zXYC$h@}F+hSn@zV;MKfWZ@V*m zo38(+CSd=p4?J}*c<4L{#)7$5uQD6T$Jd^dLuA z=_?*&)EYCt7mxLr(2!G~?BQQR8jvU}FNat|4zS3eP7^}Wr(vxNXEomJNO?V`!Rm9{ zBV4j_e(zRl-}+6WM{$P&8#34$@c!|!M7<%WRE7jEU-pud^~id`p!eTit?q#Na05y< z->8CtCAk<(0`^PPu3bK!p5I$qprT=N{)i%Z*)=~W#{3Yh*SM6bTC&VJ!SmVG0h z`Uf)|oDNC=1{ipY?1#*M_^_BO_sbb>9&2(A#>9LFO6CEPmfq|B!}I|;pCce7^%(zx zi?SV52Z;#6IwWImc+|HDbL^Hv#Tm?EManMSd-Myd5`2y^!e2P$i*M? zzN){P?{#p^A*;{3PLyvw^mDpp=I9%D9kR<0M9m$XdB=6$QXRV)!=C5M)`VTVE!iY1 z9b390`O1(7K82Ro<&HLLRCF|SIrsA5)Bv5Bj)PObv?qUYKl~B3*T@~Q#m)@SRA<5@ zMsULkfeRGWUYmdsDgbQj!$?t1OsDS&jIa;SN^YOyHv8j=liNmCt3mf?2Ky_1A_6AM*C>~*-TdvAxXpvRI~5=m!_AoQD;Y5>8iOJahFvKw5&4wjqM$o zGecDDtj!{5JR_Sg%Rg+^Nt&Hr_Re5WN$qlZ(-ZUOhAdVnIV(HEMz(fiZhuKAaym)9 zD%~|?C9gG&8bb*xKA)W$&@M?BWj+1W_%BirT5oIJkFZSW*mK2kSFPFB6Yn3NoBXol z`=)b1H3oP4`~UM9)@9+lPC`dO^>%%&f1QQ>j>_(x1T}fT(<0&B$_cYufBZO5y2?lz zZXt5LYq{%lTMOi7mouk4eM4X9NG2!|_ZuqCA$JB-4j8BA*P(WXkd4qh@A@c_F)A1q zIY30L`7yVPpmg%pD+k4M25s;7&^R$?n2thOVNy9(O+nu61+p*AlHR>~jg#8a>ZgAg zNRLfN<}2Bl^mfT;8`-Uqk>3t)*e3mp28y1%q@?8Q#@vK4orLdM45ST3H0L)kU&WOr zOc*nlCODDwI8>d)SF!>N-cDS)MLjMsxOm- z(84*+&NXkx!r8aym98pPaDDhUvoY$e?yY$B9glTp*Z^6HRHBH&1>UiApBwaG^7(!9 zB3M&xJ!}qB29)@i98xiAbT>gi@w zYw0oLwX~E)VR?CZ>?u;0l`25Ss~%@@V*V73dspACJLqzA)N?`b^Py+VYkn|Al!~+! zydf7RwaR%Y8Tb?bH8A^osmOcIie8HZK}#-89W)dDAsz-LqEyC~H(0#EvWe>*Hq|@y zwXSp*-py8>sivw5S*^R@j_wO43WB~!`l2B8lL{X# z|K|PrwQf-Qu)lU1Q6I0j=o$#ZE@Rz84}MloiIsUjv{&D1-7kl|&#vA0`}=p%%-9Nx z37rHX#MH(P)c@DC1(_X|t-X0W-;2V!qx<;3uS~0ek3U_z6xsAL1+y1yj16f0Mn|J= zCEeDFXq(|;K*q7HS?TH3bj^Ko`=?hu?XI+Ol;Ed)SK%V2K#y5g=30IfJORp_oXTx4 z#ed6BYwj%kF))&r%P(o=LNOoe<{*(uLdozGy5YYUJ|1c)U!XN6p=U>f%!pyPfBPTK zbanWZtX6V%!tCm~-illI4tRgpQ*rC4n|F@f-F3|RSmD==ma`!x^Ve!st1KJzxahBU z?}VjY-tu;GZw$}@xP5FM|9F4v^1#P&JJQ&>HKw3hzvQe{=H8F~QPn(RWl`Jmj2}je z$0YdlS1hc0b;M6K=!jDE)@t2^l8D#E&3|Bc!MQJNE9vUEJQb(tf;0QXjWDGI-`~^I z6HcbNc~PHfe^Scp5p^>y?wMXz=cCt0Nq^hiqQbj(Q3lhQyl}rr^(2m(cq1_jz|{pw z1UqCk)YUKJFO7SHMF#t13p4$iibt>=jR7y&2YaRk4i0OymzaPZ{yCM1#nb~G^(q|(}344nZ6xRXxOk4@K? zo+`{NSZEKuAjeLFO`7y5cz=jV;3%6DkNSRnMcTi(FApaNho01LQK78oGzks%? z54j6&`uEN=h%8)SNVN14=!qxeAV#BQnr(i z85@6mdvS)5fGqZXW@d~P*%4)+r6v9GclUSW#y_i&D>y zmkpborvVchx$4)AphQe^q$LbKQIRdChX7*x|Fkn<%E0`Dya$#6jKKJz1sEd0K(#(3yg{&UtxIo`LkuLsxb!aeW?K6W2Vohodb+WJzk!z z#XWla&YcWy^!U^`Pna4kwYOJ4Bvx}mZL2)9!ffTr;SZa6jXmWw;eH?=KccQNe0Xg} z!cs;c=U#tBExiUHm8V7rz+t49!z2Yp*2%5PqBoILx|xSJtwHPhxO;@O!+Do7Ob)IA zRZwtJyT+=*vLD>P|2)>_N}FKiLHMFbPL6hwzK}D};qk{w_+F>r zmS`AZY*bFUt7TkN<;hl_?P=TcPO8q_kw&JdB#E;`)%~S*{TKvcJJP;voMI%+I~1j& zSId1HHMcg}Li2>Rx+n}{bWBMt!!M=i)xeI7rGUnejL6OKg5?K|TklYBl?hmJu-a(x z!^^`ZjNa)D?!EQYsje02Bc1k>xG)iEDO~Py z*?319`y(pY&)_Hib@}R5hf~kFozcIxRRpyz*?C92M1kw;I%26K)f-|M7)k_RM$fxV z**>jD&dJEQdxvjgF!*rtQ!dk!C$i+|oi?%&)8>Rruk)sbG}j7q^J zAJ>16IVg5xyZLqFq7%=+f}8XkHs(%M>Z4?l8vx$GN@*|DanBg+dIlPawzE|ut)vzy z(-WqAxU_gpFcS5ZfWm&6!F9zGPbb*D{5OP_r9&1d zysg%?QrnB>7FYh!jDS>C!&PUxYW0oYd1dj*T(g&g-@#oQ+(&6@&dRm}!C%A{?_!J% z*$MgV(Ieo8c@*{&6q5bkZ>LnOUpBOQ*fd8{h4inTgPjiFX9E_Hz0FWpg9is6*#XpgmPFW75}36|hD zOqSf;uWkJ50C9mDZf2}h&z^I$?HU^!|A*r`T=q=8>4-%5867`Z4r$;|$9_}%UJSyk zFS8PuXi5CUauI0!zunh7hlkA+=rpudx(^BqM~ob~JuR!LX?^b7!?;KM-lSN_XlTn# zLunIS!RF^2LcRR0AHWbjHVxzSoTnIY*&}DX@CpiW?%Q?R2Xv!Y{91+5w;&kLm5=P21 zIZ%%)D?5c5_a{dSBiTLt`YTbnPp*ofKi|>n zeU~TCpKqOF{lEYIU#GHvujv2pT11P4&>r1x$q9~On@@M;kN@@BHm~YLSZ-B@z^|Dz LtV~WCJBIxqP8=-K literal 0 HcmV?d00001 diff --git a/graphcast/docs/tpu_types.png b/graphcast/docs/tpu_types.png new file mode 100644 index 0000000000000000000000000000000000000000..954c20fd47a60d9395248d3f734e0c967b268525 GIT binary patch literal 27827 zcmbrm2{e^$`2YKuH<=S9V5&ua}e1rj1!A_PH56z|{DLJ(~DBUTIw z^FIQEiAMMzj;pMq4hn@DomQWMzef4!=(}k>^ki~!akP4BXUXK|?PSR$qoHvNK~V9! z5z0hOH;M1cH$9Cm8rxJN^XaTBJf%3KiTU-AA z#P+Fk6}=*h4CS?rjkekNYS*~iv)*04dw1D0gUDsG*QeLru*qdn<5w;t@Yjg%2)yb} zg^vW2=VshQn1VIvUt%H5!OQH#2#Q%p;UOZ8`Y1|-gUF)#HUIa&B4i;zyWvI&?@UWL zALzIIC+GZhBhb{;R7mvOqZ0AmKiYR+5g)D;HPoMMDc6~HM|6~~uC0xi+YaKB(7G-E zQ52DqkT~5PG$RP&(EeP-KoCT7Rq$l1gHJDScQG|^p-PM{Ua!bh!sbtQR%*ae2xCB1 z8HWt&qm7mf!p_NgG#_WyE|sMm$BRV9#Z@|t*HAL@=ocf%(f&q*%Y6Iq-x^5qu-#Qb zJ0z&&aciuw$nD$Pt&D+Rq8FothdgJ>*DrZ=dbDd&?0idX zsK`u;72!7hdhm^a(QdNd-~9$5bMQKv2wPf1+_cVTyvzPJKmU(+jATf|`I(sVQMc4NF~SIT7k~T7kjOe0RV`P?2iB_xuyMsjtzs+>cbLn*G8mX@+-Dn!sdRDdJW zY1-V}?6*JeiQIIY_re zICbI3({ujCslvk;V_rXEUZUTXBy!t%=AOq_&y9&nSEIzvAEsaIhNJ2(3t0c&n(r{U zi3GtyOAUW6bS5@6HDRSW&;Br~bnriz4!NP8BL1|O#p!8^OQ(IzU`*TdIBtWPr6ody zUexXK9h*P7x_L7*GeZg(CovojC%%{lw&LXdb%j_SqtC`&BN>KYtfyyZaS+1-kG<95 zcj8;2OsI7B*uK1H;6vE9iJR-8!2>}hS@xf|Ycj+~6H8oPPA+XxsxF;4CI*JJU zTh`G>CB>Bt3=D;-y(`Dr1Y`_f!BNqedQErmHDEm7@VD_=&q=G}%iy%1=aL~@4+^<& za6ORZyrqvq(7i`*@5#$sSy(i;wlYfk9e`_lGM24K2TR*Me z#>CaMc6@zp%|CoH%VlhAjMIBzyIV@i`ght3_+T-QRd67Lq+K_=c7TEYuYiVd`fK$LS|MlTk9G{t9L~Bb+fk7EQLd|DtY=s=G z*KT}oYI^bf`KpJQn9HR9ehE$fb}Swa@ayzFG}5-Le; z7aX2!cVn54e0drlAHN#fW-(IlFJ^3uauKGR{_m@Oi<0)OeEgFK=_M%U;9R5rL4 zDJxabae@w!m6eq|n8M-XD|+*Wn01ih)}wpz4i4gu+I5kQm6e26AXMuJ?v%q&u}r3x2puV+5gYhHDllNM(1d`5mWqPn z8+9QPR5{_32Obw9CgFt(wJ$gBxXk^0VCwzfwt-2phs}gt;_6Bk%SwV<+i)-w{`%q# z$@7zK!NdZEcMNdCi|Y1hkv9wywBhHdj|m>yW%#E4ySm8Vk@CF0CGLOkOCu8dZ3;Fe z+e7Ie(2IFk++B5?``eR|S%g`fakoLKh2rar_{5hZ%sFku0_5R1y zR8%?#8O*^N{nvJ^3>l)Lqj{*ql`_LZL)m*YqLY*D`Qt4NA|fIRe;o;?v!8)=f)f!) z>;z}u)6=7xAds`knb+?FiN)P+BBCJKY_9eqRH^s84-UtgJSo*pdSXQPb6T1?TyRGmAD^=yoC7#p zb=MVTQ6F8*k?cpU?-<{`eJkO$X@0!7*7W_m&+f9W>g`Ah*}sDSXYy8=0YCfyK8M4k zeY%4Ih;RcjoPX29#Jo81Y9d#gUpw%g>zWU5R*fS;@si+A=No)&mi7wdCU;Ok)D z`iu9tKAJoYo$Hm}za>hdz~LSy(9P{r^-q*XfkT7nGj^6m|evF z;g==DmmQq=mjN=rk@*FmAi|v0n;M4lu_}9vLIfV?H>O^@V$F6Fd_`a0JQxmocK)?Q zpy1~LdCD;+}VgS3kdQiw${V^{L3&b1=se!ItLG<>U%R@V>iFQE{|Vq~Iu~BaP0u2%TE4 zCIt~w$Ab-dl-;SA(ASg-*C2D=*z?d^6}Ai08$9N&kb%eM`rR^Xd*JA%*m5W6>PrfP z&$O5R?atMS>xhy(`8>7GcHl*XFL9%T%ED>TIc421t;ptA(=!eAP%)FA6o@Q4Vr;wD z5X(-4h^wEy=?n}sQHwr0Sqbku59PUh`?&XP7EfHAh-XVV^|Re@obOS(0{ZuCx4<=k8V#hp22_pdkwbL0@X*0NAF!WNB#v)#t!1@q#rT0Q3bhJqEmBh z{n8M$)(pOKMS|cDDk`zRxqA^WrOu7 zM4#O#l3W~>>!#rGxqQ0wXY6x_BtKrK2M+B<_}(ct&NV^ z$kuYs{dPm9>R*r^94*Vj?#}kMF$WzSr_R%TM!vsiias|L#=q!fO)DiKgIfj>rv9BZ z_TS`vE1UzfLl|RnRC2hr71JFIN!?hA8Wn}7u($;|N~Cx8<}43wkjSO}u7|SHvGa)I zLoLtKxI|YEiyUVZUS=5p_N+4;%N!%ny9kA?&Ic(gU^>a#(-rnct- z&r<}Rep>XR{M+VV&av`R0lQQ+qR4ZO$1kT|rdwD1Pe|`xfOz+GLN?t0XD6d0+6%p; z?an?XpEoE8hkEiDSu zgikiJ$ElQ?>pK5%h8Zox-&r-u1Lh8kj5=D|hrZ-EtIzv#O7|g_2t=4^NS(FgeBq4ux(o;66pURx($K+c@vU(ebzPGP zv0?aw{ZWUyg}}4zTkeLH7h%yXEM6mk2%OBn=vpE^aL4sQ1lmi6!}4KJT~NCN>aKKD&a$DqOl*-Y9_Ov& zG#xT5^kD_Ry8Y$G!1L#3=MEEswf8xMIsRPG2t4)(GQE(l{oR;>S*JAyMabHIpp;|Q zwqqdQ=}hN)?1->Syhn|I(DW`6bow~Tky|PW0{MQQR>R36-ZM*Xbvhj65uWh`PF72tM!4C|~ zD#yEkC=s67=%(=xn$*%YW->5`UHgx3UDE$tKsAHqP|U27Gb1H-uz_OKJ87lGpe-V$ z8^tV&?K+4#gv!O)osnh9bKa3Q8}cZY652*pdzsfz_C;u52idZweP?4@7A0tO|06Yk zL1-21J%oPl^7x>Q4Th&Md{1APA*}n z>8MHlJrcw<5L}~)LoWVQLd$vWt0b3=$c-nN zVKLdU&TH*>L0f{shm{}9ZY$n}wUlvA5=h^q2D~DX(QAu^Wxfs{v;U#J3x_@Vd{T|< z=IHyhs2j6h+OS5_7*a$xgdB?57^v~vn-PQ9mBFI!?UPQ)=l0&N-EXuXH3XI73LP)4 z$2*3Q*Ib7>oI8ZsX?9{JG4K=9$-h>{+g86&L`1opa~TTn5>(h;eABpAHY*%4Tc~_7 zkDXlB!k}WKg)()hu4xo;3;F%wkut1BmjTtUqMj=rw6tdXxg@Z%L3>7vB zv=~T8@PdHF>GdA~o?hfZ+H|FLn7#UNuv)SKZq55U$n6DTN zP@juTifDBIqm}jgPDnT1Zab1iMeOjszH%kiRkrXVHtx^zHgAuFCmDwxMASZFckklK z96;H9hgNm*>ThcSw&jGXm)X2nl+L~*Fa&&ZUuc}HIxv!#P*OQJh zy9&oiNsFK489nzo==m*vV%U=HKt3Rn!ANbNZCvEOy88J-qN7*JqOb$tv^ZXqQo{n+ zVxt;24(2&PtEr;yoZZ`;G0Skejy9XfNbsA$NESO27Za=!JJ~u)`o(~HkyQ)DC_2&*7Dejtm$Ov>s(`d`?7z1hcBs7Y+`MdKSFzR)>m&6gca5n{42b=Q+_cw3F`nu0fmZn<;CJglSr)@V0r>p?P!5}F!WSTgAMb{|3J%8fN7N;Dsal<* zSYo#CDMrQ{sA<$E0QUiB zhl^@it`HNI#$}^eWXQRxJ0_uF1;F+A=&01=Ff9UUFp zdwa2#-tgzyA8+&g9`Iy-$<7Y--v_&AWnmFN`kR56SB}E3mGm$&&*k2Gw%O%~dR%S@ zFtlnm0s0wh6Lq!U^B|Gy=-_(q8J_Mn-+rlTz@z!h{!ez#M4d0dRUHkDO>jYA;$Y|= zFYrlxj+U|@I$<$!apfh+3S|li3J$n?Wyi%0;*g1JK(kKp*DLbx+b6Xf^|b+~M~f-G zfNdYbmuRxXx9523Ql~r6C|Cdv>$z@RJ6>qfithdS^XEIMGcUlo5HB%U`*yeZcAvUb z9Y9*Zm}{ezp7wDRNP6$ssow;nYE zI3DWuEz>1}{;Z4=m7odD8!NOtS?AjZ``h)=@FU7uAGb9Df#Cd9_FDQY=66<1bP z_ARL>E5rTZl^E849~%qH`J^(Um7@dx{NMim#%N^`cZZ0GNHm+OA#}~;-Q(EZ-L-6b7#$sL^u_ukT%m!Qnwnr@=k9X<1r#chg+olt z2r#gZ&!fAL7*X4zs{Ia2;UL3>=D=~=6c2=RFlPR+ifB~2-FPj89&K6OoMq=g#U&PP zi}nZ+E%zQM{u&mRs*@N(5vydbgS%f^J2u@W)fF)qc;S%uX$OBc9^ z1q-VK?d(?39NG)rsT=e2EeU$U5w;UwMXp+D&vebdspI&kDe}}wwb)kamU*iRR#9wSG}sRqE(mQbgfvK zZBLi7&P9xjjad{+SrpZ>`s@Uy?wtIqsbIv^RieR>(0eRsaNOS9gak{xdt1EQ*%%Ki zh+Ip@`0K-2>H_)kLRGFhQ>cI0!#rv;va5b}s)kfcLGd)tp1JPhU)i<4G`>q)Zb+{4 zyzzlC4ms{wqOL8&>b7)r=nvQzDmhDqn|iOadz>(_IM$QuO07lrW3@4HyeLrIhgtG z!BDvJ)Q{oXzNIIJQ)%T+%wc$N)5Lh{V$Vc{^{sYE*hNNq6Z0vZj9?<*Lb#s zd|p33E-oZ`XyE{(S>cj@Q1q`;((sx7S#lHaj{6zUk%|8T)3D!R6^>+tf5Q{+uC1-z z7B`y_&eX}`^6lltl(#hsXc39Lq|dag-{t=AxWA?VNz;2wgAI2?uQKI@t-yD*=h?~0 ztRI5SRCt!t6scesdX1DAesB4n%+xm%BRt#9F+GY2)z-Vp?!{u&N$oj0bXbkL&~(lA zdC%GL)UlQ<2qy(``ml>BDPJTv{$ED(ztf7i(?A0YwYWuq1P2+a;!ZCQz?EE#CP6XV zxZ`zmKAF4&#UsypEW*?wAk=MBv-jjzxrUP5d$pvR{mQl1oR@CVjLb$nai+sUQi=-A zrBQ7im2Z+x21|nTd*3kyqr|7ZwRGkhJAZ?zj+;;*!VwnpR6;~1lk9G)8NV56U!IYh zZj62vzkvAdKebWwc*fMoXVfk6+xza`r9363kas2%zD%lxued*uJBen#<>I|WrLT2mi&5(WWMKM<^gE}2vp*qFjg5NsI8(9}e$zu;Hb{$HE)7j7 zi;HD(MVJ(YWq@xPBN{<&TIEw3MOZ3Q&jCfH=j5QS7eL0F0LLh8F2RE9a*l9xqtzd^|2tc>h-z=q21RuGZ~a; zA&rbcXkNg5JqpbYa@VYD4qIAEYwHutn7tzVkM~^zJ3tV z5!LJ(=~0cseYjf_%+cq`YfKF`;LNKyi!xT}3fyRTLPt8>m)4&x^6bXMgrQ7}rCEvm z&ar#dE`ZMaYJCaMx`Eu{FdSaqsr3D`U;2_wO`Gub?C z`BsO$-)izYVh9~&rI+(_c>A-%TA;cr9}&r)UEejPDwBfC&U6n0k+h!8x^=R?>~HI6 zn=J)^k~UfoFOt8^wOvi45|=orl(PBy^vS}f5`jD8#BB>91Ur$I?A(+xXco;do0;R@ z3t55V{@$~LFO1_Z&wFDx=IM-O-$>tUwKP+@rS9|bg&s*qg9!c;d)o-{=`~Ulv5@lD zM@u^^B4=0($Yy7apgNf&o1%tYremP<-am!JjRP)<2%}0i($}bczmpzn4g;eFK8d6L zEHb~c1q*ATeb#Npon7kn4eEi9^VHb)sF?3$M3P5ZE^p-Z)e=VCGq?$oE?;SI!q<~n z6m^A{HPKVkSQKSh$<>dzmG1ux(o5+~I2fhh9P8^;Ahs{;$L^NV(J`NN2u$+2UTdG1 z7B#5N+0LY>6P3;0?(uz1?2%i8dZklaK*)>oMb|j3EgT#Y_hv+l%zoQuwVTV`b zTK4=y%6v(8iHQV;*8cQ0lKrJPnBBNy#l@2O;dpX!WtL-TZE7tCD~*oe@Gw?SbpQv2 zbQOp(uJ+eGV|my?Ci{0n(-cWVzrrP0Fs`0iTrtw^j<%$|v_z0})|`{S5uli|>vva3 zXkzITo|PiGa8~q1!aqfec+!O*xaXVcsiRRrhDdp%g8aeA8YMGqYS;U>{&On|D8<2v znx}ZOXCzN_K1Zn#1~pPdMCRR~h*jI%^H5gt54$eS;Z7z8IFRaQ?58v+`*RQ6m`H#Dk5BzpLeyIJ-Kr4?~s2g9JT_Yt# zIaas znTw@uwj`3M74q}{{(94tfnxfsRK+{u4IM+B%eM%j&l@(E)^0CV_o=I7Px6aOM_C|> zRF()iD-n|x!g;sjbIN2u!qL*=@sBq)Qae4=j(*W}V=Hld6V@&Z#%Jm8qPkh7x}0&P zy*%#ubv;?I%dj7ARh%r`SpIw+=jhJAjeyUGv!k(f6LW*fS8i0$6c`STt_LrsK0Q3l z@9eJoS`hfIaZs`bx3Od^AU-poM{yc^59fODIxp~o1`WN3$y&KuIsf72Md1lnvZ6Gy z#w9n&TW&v7{Ewy1XF5pjjl5}z(^-_#=bVJ^^*FSbqlBccKZt#Ls!=pQSx!hSkloJ| z#-#M+ac$)F8;lc&4`!+_;7u0VZvTFv>YehF|Ip4`XWJ?u9aocL;zGUkN6l?DiWJ(B zF$dSrmQ>nGr~Y3!j2O=BM~AhaY(K~Oa-&;HkFQyRwE!h;d}iW1{IntI@#$eD!>Nyg zFr7}pOoSudSO3Mss$edO_3hLnzD0xFpP{H|DK+UWFj-* zz5UP0?6^_R?CCS|{EJM%>kKk7nWCyD#?>zK-&>R+>4$2T95-drJLSgK?jc}(D=eXnd=uE)Oe9-%15&M z{Avx$l8urM=d}(dMf4(J>oayOrQsTB1Sush^|K$L!NYE>SM{;dberA18nCjv8RV3O zn%ZKL4Q+fu{6O6!8TBJYjl*fo-L+YX+)Gy2N%Yn+`5; zUR1VQ!T3up&alDE-%BU^JgpQ;96=zVOuB{+RYYrLp)#u_O*ku`DWKWSFF=x_*R!&_ zHc|oKLBo?ti7TUKhK^J$P4_KW{NYGonu zlI93{K`p{THCu{;H|Q8Ux~ZWgQ?QVLN4QQK^}U3|2$KpnaJD;enO?XbYq`gli1UjY zyL5C7!0mgAX`8o&Hk24TKk}h-0SJMuvP|}nk5gqo?p))zGhqCI?`r0-?Ef#~T}ZN~ zoP;SaV!l9U82RCY*HSMla$iyL-}ZKyJ~!gD@GBmqgS7YCKu;G4NZv(F7^++l z?-$_oz6%Qj$*0pD33_OA{TeUj_{tvUWY1C4i(+n9c=_@r2q7;EJEeUy*HDYJjGt{s zSy?hP^t7S1WI+XJ89;={XE^^(lR=nJo*f&IZCB=hwnd;J&@HiVtf@)iTMN(?>T@cj zx3^c+W7R7ok3Ewy;NYR3-w9AfGp}hp%81NpX(V>#htUhV+5bFE&p_0d|C!Q1D zPk@)e&og&#r}}LvK77b5#v3*h#j5O;f|1%m#UFG0#LBAv-;b*{b(kwaUvm79k}D+f83Dl2SpQ=lL5X^1n+qaXir8@-tqkdys}a|wXMZ5BS$FKI?99aG=G$3hYX?U#Z2 z0ab?ES8Ul43);&Pqq?15WdL-F>guVt9?LfboPuPH7!ge55fFg-gNuWsZFKZFT3r$} zPHye{Yojz*uZnqXPT8FLgPHvpd-nOw9ovBu5c2*HEdv}gT48?`xlH^DkO_#VdZwM+ z)<*7daoGZ)E__g34Sfl{yaKZZN#sx7gYKrL+1^ZL(1cd#bBiVy0^I~a+XZYqOD)+( zfYx(;tnFGeh@Jq!1Phc&flIF?yA>8hW)V=D#`EdEoiO5nM0nxx}ehcSry1#K!%n7{3NUoyRpi7mOWFv8zWivPCVg!jNRMd z%VCAH-YHt5{zq=f>&v8;N@9{1@Po;L?Du*`S{h_(85HQvqnMkQ++Kx*K)Yry6jZ)s zQYXbwN&+5!r0E0_$#;@Rf8o43?g~MoCss7$XG|Upjiqw?(ShA5xDOpmWd*;D9P4%) z&QJLyCF?;f=fp#EwpKZbPfCYE;5K0+F-p2>W?UViU}1(jI&*7l^pup?2p%4uVF5L= z<5S14144azgjUe*=fXu6R@Rc|m?MSN*wK+E7fzia|MqA)gi#+?)OU|tOos42z$R$1 zObJJ5aiU~F?mYq2Sb1}PBow32Zv*K`LjXIes#-n)Lj$>Q8GQO9M~wJ7 zyhYZL(Dp6Jqa$2c*x2xq0HAiD)<+uQSSfQzz{em|moBxH(osfsL35t{r9f;wztP{VVMK)ne^KxK~^UlQtds( zeRe=b>?*rqwR2e0<$@f)9#0 zzA9xVfH)E0EoPAl4Bo+auv%JM_hIK%69r33N@(9M3}AP|yI;R{O}{M}`VeE~wsGSQ z8*nJVR*^;v{+8xuVt~8k!N4eAhJ=)r9)h~qzz2SS00F=UZ<5Zw2W-%Ld*Q!91h0RO zpz8!yKLMOo!e>V#g_4nx@#4j%ZH!WJN-XVSrYaF){&W+*`N1dqsuSR@yH#`O8P#0vXu3=RjUo&HGII6mvKwlJxJzP-bz=W6DHAr-I> ztnwBU^R~>*EiL28x)%Xo_@d+9X)F2r`(sTtKEo+oS=3Nk^AM{dF8@0E#c}dX?FJTX zy2mAMypQOe3q(ZvFslN0d~LXl0PNBK&J8jUR=VdfK}9*d*$%M=^*jx5C1*c)S5RrD z0uvW(qlTZc*OAe#XAgrO>1tnAcfLWH3Y#_@9e#d(5R%TXtQ?=WpN( z8g~EwRo~@MfSon4?83?ZW36ToE2oLJ6k%eLj`{@F45zrG^iBAqYrz_8?ygOT1}_mL z1wsOJ%$F4@p{<|b#3lB;uo4ci=i zBn`Ps)j}J<4k^snFq#+Fga0(voKuI-H^08vKC8p23!WdOt+YZ8P(FanoBHC#KJb`R zQ&Y=oZNU7m3RrJ~tUXrPRA1lg?{DUeo+04W&_Y^P!-t%vlE?Rj zbPNP9`8SvU{k;Cp4|Lz4uApxns=#$!_$7il!gKTUU>ZDN%I$rPRMgal1p|}>V2a<& zhe|D#A!>qS34rk+8y`_v^M@@P7`-`xHX+8VpP%jcl#ee`_>EoIytgEq?9h+ z&85JiIO9Dr?SwSKDeCwdMxH)bZB^zC(TlJCAbrJDKu%81r_S`XN-1|4&uVyhxblXv znXz$xk2U&q{StL}OJ}Dld2lbqI>dONe=RgHqTq@^Bh4LLB@D@&Qo{{iK|jEWWO)kd z22tdCVk-)RVb`O<#m3n3B~76y2JVS2g;j1YdgSouc;eJ#X!*o*4gT9~jHwiYOz zg}e4vXu`?UlZ=GKq}(QED|CCYdw*lntZ1-dyGsmaQg-&0)YTn-r(yZhQ&Hg#R;U5T zFnl&EEJLo}_5o6XuaK!?8qCn+wfglzhQnQ$4&Xj`u+$wNiSF&Twl=G>L#vBDAap)B z@ERB+LmszASRMu_DJiL_fKr&8ACkcI^mIy^AFZu06ST878g~ab8C)|&0{kH8Ro7Nm zD@N@=gx=klJcj^I3Y$T!@9uJo!xosve{&jct3xGmW&z%yA3@9o z6Co6c!r24W0Q4Za4N5<`^_=xTF!MiNtDK#gfzvAqh5$l?xI$@=rVQIL5)KVYSeUfF$dt(^MnjMK%p30|!%MW9|-@cMKBHqx)DTn}&wCrri zd%UWi+^qrqW?rhkZv zii(YmO-SGjyO!w&0kB=&z}VRD!hi7w4V_7dVOY{Idk_$C23xZT)Ei7p zOkmn>pzX)B&s`VQfDC~?h>wqNppcT5mMF&70t7+g-5bMYA-XpIqc^#pNk6&dh8@I& z%3m}6R`EqAM+YXVxFc`UwTdKn!LVEvn>kcFph|G{^dtxhVet8BCx}taKvrG3xeJBX zf~AE8v|gv*D=-j}kXXLmtgo-X6!(xL)26%z2A)7m^2pLOc66dEr4cUWxI6GHqsJN! z1c=hMAeS#yGeJVb!*P(e^4e;|<{X$w1B_UK^(N*kl7tl-dGReWCZ@&ItjY-sf$72K zu#jUZY>+2J6_~CwV_>Q$UH=}t+6y)VgsaX(n`$ww|;C#5C>2V3>5$%TgA-nJB z2&d22QD@8ibmXtw!2=2d;vJ#lbqsNayg7C>L6XQzZT~Ftt)!jtLp@I{1 z7|Vt!p+zj?DyL_vzb!>J{^V$z1su+Uo}~chANcuQ{>{MMYJtQDsu5|hWXOW#Bj~(m znqPzChXE9fABLJcqsMNnssvI&$do&~x-hR^q@+ZcL`4%wNRA+Cfac{3Qii&|@+2@VrH{|K#FXB{GQB zy{a$Y0%zXas#Vl6BnmL z^p%xC0D>nI`tBVW5z(XD&tNX^401lmKLwxm;v#S$BN(N46}6z13-t=ZBp|Q@3LG*N z>i+%v*6^a6+m|gge1V5?rW-6U5*n0z1F1zZUuNG@z8;ry5LZTz*WAzdFntGZLeO?# zb)fzZ1ab-(&$+G;@#ghw5$`Pv=+?fY7aM>vOPD5={RvVmLT2cwLpN>7EiC;*Ow1(; zifI_z3aMA~hI$;*wyYga*ctk`_wSuIr@q1H9dy=R=(754RN1sM6M`%ZNQsFpn(hu$ zfXDTzGb||V!krY2ZVSFp!wfBa^P69+dt|fvN}hgAXFobR((@3EUTIW%)lZhooUA+b z{pAI)ZBNll7#*g$Np)eoVCK$5Qr8GXf3aj-NY+NDb=dwd=-}>(6|{J`JVXqG!y27= z!iW=H)~`2j!^1l|J6kS8Cbb_F=84wX=VvsG zj78jKCotq_&I+4lQ$&Of!WWPZfW!OJxAY0k#=-(8&>%Kh?FmHH^>9qX^+8LAjJ6UK z3quY!B_$>86mBZzfR=BR9=e#2&O)u(!Ed|9Q4%kG=ko=4=}yx;9b~T&H8Bs<^!G5r zW$8?8YbVe6h>&Q^~I> z_V#2b?5_^Hg1iW9b#VwbKgSl-$OVZ~f3i5(5NAv-R`D1nAAq~O z2c`!w3X2KaokUr$uDt!heD@U;f=}R-j#fHKd2Oa~cNAKl!;XTJ@mPRa=py54E!}s# zGE4tF+~LCTE>M%pf^0|Y5#d*B`vSG7(&O$Li{F2MCQGwe$#P1o5f~1{>_9r=?9noe3yCARs`AwzRAaey!i; zj_-<-@X!hjo&^O3QC8Nt%!ff=RxOL`-yO(VQ&Ur!P?-%QaO`2yl9H0ru+m|zhCh`jXl^#0Ke&@vo{R0rfgO9CUCwS@Deatp7_3kd?)xliQX}2T;qfK#!vV zF4pD8n=43UQqovMU?96K=;ng8j$lap6!KlbG~uD4NDz2EP%R=LcwHSVGW$>eaR_PM z5G$y`j$pbHGJ3kpmoL-PCnqGdMVa>dqt)?f8Dbxt|mfvnQvgjJ~TS za5#E>$KODb_C1&oMi<$g{h)M)t$pm~=H~7h0*?%c$upNikSaKhkpHYtPiwLo1>J$6 zQV5+GdEGuEK(@{4CM+fq5sbJ8Q!T-iS`3Vj&OkmVo7gF6-WUXfcphi)QyiH93$Dw* zNkf~~`*{P-fa8on0|PUE3)?c3}-Dc%kj8JTY&nUOr4yNKM3 zR{?Bw1h3!I1nK1-+&7~d{Kg2hHEk5j61+dG=z4H=|1z&hO|d*Dn8YOz0@JnHk5%=! zn;neq!!#;{I+*p>n}E@3#Co#+6v}Rp=Yo_5@Qj)Fd<^9Sj5rKj5Afa#H;srrQ zyD2^ir6>d{EhT$FBhZpTz=CPnrfK%=a4H&_^TWkd;B0nN8nAx`_W0mBUS1VkVsAMI zQb>S@r#r2WabYA3;xK2}a_iPD;6awBSWquO7JXf8x^92s0wJLRD0M_MLo#o0aa};f zB_v{IVPf)?ZCTtfMV%=b0?#{OizYP2ejOJ_gIs?5BR@YM0(#9v(rgukufy%dtHk)S zKcVo8nZSFHkM|VxtAMPKo9@ehAb;|MQo-)r0)j(F%;0Br6x2Xe9pkx?rbVgZ-dm9V zE4-z3-&-|6A}tSr4}xGR|M5U;D*BX>(!_yC^#z<$2OU2attU$?utnQc)D!-)j#Mg(Of6q$HP#H^{$QYMJU%hVq~dW1-4^~Jab#x|6h zv;i&_7Z)eSziVh{XllxUfOOdStJnKvc!q-_Vm(&ntf4`UKy4RFKUw_nu9sJpS{6Ln zU>pKd%(rA<3Aq-&acWtp~gasXBVX;fF#jiDjL<1|9@6>}XDD6q|c5FWv7p5t8r8k?n;XTY5NG9A2NX(cl= zvtfatsHhLor9WK)O}g!l1D6fg7Ryu@)hg`Qdf&XTDk#nd`sN1-0*mR;~ZSnu$>?) za!*0wD)B2LUnh`&fy~VI7^91af_m0X^!oMd!&CksLrX~k^4hiS=i*YDm31T0D+%m_ zhGy`qx9dQzE7XJV+z5kM5Qcp-+=HYH9(I7qUBTiZ22~55f#pK&>q)itZz5JJzE3D152ucXf z7$r<=4HQD8A+tctMFP*hGNJ4i;ySIg^+M_aEMgvd@csMxuS$-%%F2C6MBq@kW?h7a zjps{#9BJ>hOTm%Aiw0(7Q13DPwuB2$hTYZtm&r9LDaOh?h__q1zs>0{U9v129uKIe7kGD~d$qpMMJm zE%x{JJcQ|nsDUi)zakpIesrs`v$Fu|!T%7(Ur>ZX#cWSz@HhvP^2~D%O})fJDk3l> z7mO@H;3#cfPd%F%Hd%! zFe1l@z1y3b?n0LA>wA27INcrxpWv%&BUrSuYL>D%EJ_X;is!Zk^tfc_^K?kN9hF1z%}d0EDwew{*rrq?8hY2f|8?1a_i;XnrwH1|c*c-e2|A46zYm4v0O zhbk4|>EOULIGjNR4px{@f{_lWKKm7xP<$ngNh{AXrF z`2x7PIsih5bf17~-`ryyivs4Y_7Y>>8&d7sxbek89`GHJek7d%fR*+Aef`*i{CrqjPJX_dhez!Fv*>dud_302h6e^>?Lc{LWMq`G z2mu#rUcC)1@Zs=008&!Y*wZn3-TdRn zz`#J1WeFy2g+39~q^&+T;nD|2YVgwZ^48YYbU7~#pNvkPRd}o!B9OJ%90nBdo^%*rQ2T4T6`_UQLZ?6jb2R zptbd*dX~<9^=e|QzcDP)XG9e7^6>#Aob~xLHqz=0%qB(1kpTf|Af1@U8^5R0ym5>3 z^9Ct$nmX3j>_~7(2#^(QY0x6yyNBEA3>FMc7MO1+C@eIEn+`J&$P`RGKuY9Omtr{# zVJs*JL9+Amwzs!OYdswy5}KGKzj%QeazWC*d~t)&)&P?Mh6TL5yg>b}Ggih*xP9SK zKR}H5j4HU{fivmZzybi^rLJ0Dwt`ls_Ggu+{aGR8vOIq|-eBdM)*XRR8Yd_zi4K8& zoDlMc_RcaS@Z24}0(CHGB;vZzcsddIteCTWQwY;Rg3RGl-x3RD8*Cw18y3PDV*@aO z=%p+RrQE%HFaiz`KL){&G$gXk0@|_Dz(fC!(!M;N%JhH#K9uATNs>%sOW9?KB&j5O zNOn>k*+(g|Bxi%t&);4mKGr?%lCSGKEJxcfLJgva@>DXd7M!lr+9_-4Q>Kt}S!0jU#rt$Oh<0rY@VOO1Gg?PyG;gizQcF*`$ z8@so=DJQehlPH$y8HkE4zRa)bD&{6+A9Z&h!VXCgs_Z$9`h}xNHE{dkzqu)#EFFml zvr$F{WR1=@PGl(*Tu)IFx4PiZ|E2pB4n&Ih2h9P7ljtc3V^S*&+&aVW@ZsxDO17JX zV(U2!Di~YRXyPg=0#p))n`(-wjERwv#GDRlcJ3)|0*t+cK#oE$PMEp5(wFBE&}E*g zZX}ZcsH(MXTUxrEqa(#;hbXT#l)j8!IexP{H5(A|({khkD*;9Vh6Hs?cKZUF`%HPwqHS5 zjd%BNm}WRjUt?41w46YT0Ess>F@czltFSLX8U&N|=;ALJ{ez18G`)(@IvVY-8&hyd{G1$~5NU;>qF3ZVndCmXH@Qy| z4W)tI?}YuVVKJ_Irb_$6@2$hG01t|wp0VEQ(SbL11s$wGf3|HY0L#vuv__S_zP{X3 zO+H&jPininyGLIyDk{2JQew2{m4|v?eb_xfR1{KTf%1wJqzWG-CjsG;EUbGJP#7)x zajjsN3|j>csf*I-blK>U-fH<_Ey-qACL++5>+o@=u^4i4a)(7g@t$~NZ-RnrKL&=3RNii6&06_^8};x6@)cc5nX8KXyEtm-3t|9GjC&s0y6g+ zf%zjI_74n5-4A_P+WqPkCB%dM3+Y{p)fkxFMxH+rA%%UNqd(m3lV&{z+S#T%bk{21|9U&&AMK^ zh+CQjl4*NP>T-etiGv9vAJU_r9hH_j+u41G+`LswtI9|hf2O(|6}(Xe{US(%520k& za4Y3I;yHW0@BDsXvhBN59af~y3hVM1w1lw-voI+3YRLaNBvJ*9yJ9CNju3EB6`xvF zF3jjU`LC6T%yib`6OxGhf7iU|&tI`z|j#00y-bZ zKA@UjB0J}qiU|tF)@$Fsnz_arRDzDQ>O3Jf)ZbrzxG5T|AH@OlN*0?sOPznBPes=N zXSSoG_mLwV#xW~FzEcyi*kYb{if`Qb7XpN3@k`3)f$5pokU!2~{0K*m_SUV{2BlaE zMixSDThdOAb;oC9m|0pI8a9LDd1XS+%FmaPkvR(55`-?A2+Wcw+xxE|hC6pa6Vuex zBo@UlQU6P9Zd;iP@tjReJkU4^L?0R&3blb>M5Up(x7Xj+*Y_hDXCO*zTiY(bs5RC^7QsLDSC$y7B#hw#aDvZO{+pDXJ>D3Z=}*;4z9#vlY-i@(N4vU8#iK{ zm_6#PF52GjQ8iiMsjZ-(fG$QwrTkrAtL{%95MY$2pfoO8!0TF2=BdCya}yJjef#)% zd0*cTR;!v5VY-w}kB=k5PC5;ke{N{U!g7Lzh4?P7;NV)TLd*ap6<7#09ZxSV9!2pg zJ@{ZAfSps%)bt5;zGG_1L@7BODj4`;y0C(H0U$&@Iz14$m^1?TLSUB6FjS{4TehH- zrJ)zcCsr5bGoofq@k`enJS~xE+?uhi>sV|K_<;JEf6^drvY$OG?fMQ&7eb6&?gjd? zbvI#4?G$@F_32G0_-cd{VR#5RP--cr4Ra~w{g--t@;&dCmOk@P2L$VU@dD4q_P~L} zUym%>E9Lw(F04R!i6iZCfwTE=qhlOHd z*ws~h`BBDm4s7;y(bvCDOkm)en)ri;@&5hcv%AKHj?(x7h3a>!?*0px+pW<2T&lNuoS?=AifWzA*Rk?PrlIP@anFVfq=(UF4PdU|@yW8DydgD~PeKJa+Qdnp`6olO%u zkB?J9YZ-Zw=-f zm8A0ufu*CgTaMNV-o{y|;MhRqIrIlug>VdJ4#D}hf?(rGDP91LGp@6*<0SU@w$Jya zgku(^CxCCa9z0A;_U675|B}A-1mO2>1G?dZ{TPZp82V8oJG`W-N{CuB<&%54tM}!m zWA>MKdROJONS_V4jbZb%GrL+atO$J{r4pL#rDQPO#VMdT9UamYaGA&Y@1x9@;W zG%yHoG6pv=X+&Hwin0`(i&2CbVie6PiW3GSA>hF5RgGR#hgqf8d@xcVL`HpLk4n2& z|GLux`6wl>D&K$n=p3~T4ibB;A*PDHks)&K{+B%!4=WUgZvqqCzh@eV5#T%b?$J+t z!vN{li%@SjrU8Cp;55w6&d$;@9*LrLx_OxEkG>ALt+2{7wG_u!dmcODD;N1NdX-Pv zuv{VLZL>b^*rka*<((ZK%N_lGSsH!aaG0M=wb;N(#Wqt762DPcC{eQPkZWyJ#PSqV z1LK8dgM1{{$}?w;kNSCFhG%nr=8l+~ZCsmz|El^{A>zcAOn`#JtwAeh;YU0F|6}`q zMJNB)zXC+?@6->F&*u&i=D?krTF;|LkD8l5lgilDq0X-0gVSq2-Hk^g(!nKkIf2Iu zllLLk)awOEBrt+cDtwHLSg;46MGDFcABfe&K3xXJ+*3WZGiDiDcXMN7uamYmHk}|u zP;caO!TNyj#BQ(q@)!N>nV*WS6Vi2e?oiaUf9+{lg~7j-TseT}EqJ*g(OAHM zeMUyvo~c8@^78Vx>`X0%dQ2=|fG9>=L(?m<2@MJY4~~CU8JPR{_%KxM>;49EjTPDa z$1DW9Ei;FgU6!6ya@}rn-wJBc8gwJVNhKHY@aso?T)Xczui(o7K*>o6r9%PfUGA~xA!0{{aqY;hNC`QywjfzD8KXlc!IYGSOWE zi4Xqs4`!#2fho!LZhreV9I2Jy^nM_K8HkjnNcsWU!13d4vC|XTgJ4*9Z9y{{0H65A zjRi0Q3GPui8yQJia-Ez!l5vWiy)vZ_(p+}X$VpaSUTWAluL93Waf*PUYdRMGh{C}@ z{!fpl=8uky;5hO1^))jy3kV1}aNqzik8vLQBeaC)T*`)ieFK#aq+(EJfu|_BV!WYs z0@Z@3z%<^jr4>$DPGI(>!8p-y+{wek1I#s+nhx~%FFJlIBBUQn^+OBz5uD`%!S%Lf zE@-O6qX1)X)r3N6>NQY4x)uRjVf_1lTgGh?Fq7Xej!gATYdi82@40_ACdxpou zKlg2~n4xS!4VZ98yLaki@07?sWCl9KnHS)44P0C@jq~KN($IF5mY(4`Lg29NM3n|b zoSc%f)wO}3kdPoM-5M#JRuO#duihtEqROBOa^74S48w6RvNGR*u zz;j_$f!2r{z#iczSgUHHvE&P6yvviIBVj#WL)tq!I5|0;otzAR0Mzc0w8W0xF+a`^ z%$hlb%-zH8?p=R}GgL$XEnv$*pSstII`t(b@dH>?02UyATg!#L-$*ez!?H`55Zav| zoN>-+gWXoivF5^+ug?}=ADGCn;MBuI%b_(0%61^9!&{IajU?jke0|cjcW?Wo3h?mQk1&OPDASw;4)6|A~U>9yN4D%7_pgVW(ax|a^ zXwGL|-z0JNDnc6KDNX$o1G}jFZ+1!Z1G>?A_bd*_mpZ z4|)xe9!=;o{3~{ZhevNu$w7jIT@rZRB}|xz+db4rrhedGzQS)z`UeLC{(YVu9nF)# zg1=Rld;u!?_H6|ZPjqZ7a+&{wn2-rgBk(UQA>ea-RYdUdt3k`3i)cuSOeYOkS}?4D z(6Gnev9!!?^9ag>kA>1>P_Nh(gM)+M_#kyC&T)f5G1{I%`iF*w#>RB~`g9})p$cLy z2P7?aHg#T7xHhiKo5gK_6MuN<_=yva6$}<-9}?3N7mj)96@acPHa51dE)3Glg9kfm zvmzoQ0I-~t4nnI12|+A`0Eh}9OaTn;-TN5V8&CTG{rBpkq*H3K2|2_R#AsjQ`KQp_ zE<3Dt7?2Nmc@-5G7sFwP-&I`fee`HPf;{~F{Xt?%3r3}-y}~}mrg_;JdX-yZGU!G~ z*^A4n&-}-(4rqN%ezH<>RV3;tFIPlaS=p{^kB6^OU6S0=KAu>H1QYMA8{RTh$90rV z-<=8w*p3WcY&yIfsNm5IZ`G7d(YOUg$-f~{$Jh7T*bWc51iVKTx)+eEI5HgU?TzxR zKn)1Tyke&{L}HgT_oz}cZavMI8M=X>K#j^g=4T3K;KZAG0^M`r5DmV z3{Z19Ln1d*Df&)zA*WO-_9X$7$^Zy;YDt@HD!b`aw+z#DYs1PrD<{b#x_^wMl! z?TL-%%Q1iMzNb;{XUpXL2tdB0@f#;pcvp#v8rS=q>xxsE zY4(#^m0Oz*k7H9g>d4~~d)qlV^;U7-S73v04tQB~K!IIj4nl&vnulF6_d6{s>eDaL z*Fl;w7`fmU=tpA~)X0sSH({uF(Bs2JQu7_ZxR_CB1*D+pn>`&><^{#c-72s6_&sKv zBqN{acvP9cJ-sIG%d_M`_CcN2SKZwxSXI{lpil#v=ugM4RXcY`B-UU8IJiUo>G9X7 z<&>+1NhyFmA+eXRKU`q-aFI0YKsrc#Vqzj5CN9UJL}nPhB4|2EIu1mJ?VdeN_$mU~ zXl+DM$-Aw_{nsC0Z$s$8NQnX)%=ML_;3PMkAjX>hdh||T6mb@#<>}Kr3P*uAk2j2r z+aMSPJdvJ}L7^xpY^(o+qF%gFMD5c&{a8^(CcnJgqM)OvxA!*)05Gcom@i?DjCf_3 zPa_p0ZZDdJ=3PdxIpQFPDl!Ag61BoX>KYb~;Yi7Vk)`#-=udF`D!h-6=PvycsT_X5 z#wNd@;1X!in=-&C5F7x`O4~e07M>@+_N|*YO_h>CR9a<&1xF^yaA(On@)`sr*n1Gh zBHAZ!R99;d$nI?j&&tYz={zHBX?!r zIOQGjkjkD;)AZ|o^@_6rhj!hWxsY>*P)d4y0#BS6MWbD}AAQ=<&HwN2aSBlO% z-ad1`qyM5}@6Uld@U3F-Lp1ggnrpI2 zNvU&miW~!e!hi*c?FG2f{cF>!EFo*4-ipEZ)inw%anHSVy+7srn;h#|MUVjv_t0-| z>&@D2OB_}s=}YAaY#{Dy^&nI9K!OZkBRM*BDL_Eyu8m)FHG02fs zRD9an`XFF5X%T0^tSqo@9ed(F6_STwxQSW7R$BkU{O5({v`35ZDO0GLE8_QoJI|j# z2bG^W1SX4?m79wmyu^l1aWo%Vpg@A2{t!%s-(vzzA_^`?;;NcrC?zE?ACvp}?Sr?- zX&;&ar{cfnGDQCMYt~F>JyGN3B=CVtq@kJMVuArFwY$YZ092k#oHjCA9{e|$%N?m)q6`ZCoy{1Cu8qd7HW5!cZ<4KI^H&aJl51j(zj8#)32ppfK;tpbpObp=RtQab}Sn5`6MBw5yS zd|aG_a;+$6s4A*$&;E(ipjwzjsv*lq7FrOTTMBHgLH zwtY5_xR{tc*q5Suc8s#*Ahv7mITrXv=Ire0?CFIxzxrSpWPcs7hbm>_`CU;tN~KR; z?BPPNM{>HjG|$kEZ>|(E(dMBbE4vjZ7?!k|ps$Zlq+}3~H__AG-Z`1q%-aT6R>X&L zz*{ufjO4gO?Y6s9Rg;D;@8pKSt(pRe7J0Ir7s&4iJg%(GU!g&-W_?a}30Y3692D1{7cgD`2?B=p^Sf*OsGxG$Tni1tdO7j^Mso8>K<21HRj_$M zV&Ww13db6ss6YS2$6H3M4;;`qzGi;@^g1G5avt}zGwrsNU{quW%i{N#7-1fH)}AM2i6aob0H6NYM$-H~_F6yS@!TQ&2ZMUkbtbebJ(@HzRlwbhlC*rt8>A7k&|Uj`({TF+K9rXm3Wjur8x0r;5Uk-tq(KREV| zg8=YR#(GdPxZMcs18)ED9SR;L`1kb8)DC6-#Ec0>Kd;SZV=N%yMi)m`}KuJg^#BTxVGxzgq#r2Ur0JS37iboQb+ga8;Ifajq@ z3T{!@%)8sty7AIl?z-KVqKAUd9TKhAHRxL%bb*d>D|S&~A+n-SW1(6~#v4it!dYQB z{$jFR#Vxz3($Hpq&{Zi#lW@%LKEu?is_yIZNxz`-{&8bWcx))KLg)K+fwYx;Z>7ZN)SiT)^1wh3 z7BM&BNaTrief4dB+ejv`3~gSGUBpTBP4r{bEwPHcP0`k!AgfAClh}iof9gxd>uZ&^ zair$b!{42GbM5-3D19qSOGtHQGONts=AfH!#e2}E>2n6lL@0@=<{W*(?ptfZ9}ciF zF6N=e1zM6VzQbe*XNu{gikoiU{X6+MkQ*s6*)94pea#s= zE&?8g#~n8l2g2dwb%&V$99&X4dnniwX!y!kM#2)pZhz2o2Kx1`OC>IC{NYMBn!x@l z$b2NKA0JozSSS$IYjhE3Uq5IFr-V*xf(bqJZ2LhETEjJy1%pF|7*9RlHTl|vEh7G; zLt?Yu)oEGb7*^j8T2>p|o{O|E)i7HE*>Z&XDys8c`gkI9%VKI%57Yk= z|4qtS&rF>o(u{4ubTe*j|Ml@1K;G(?>-f%CY}4EHL-9OqeV~8{F_cyQTbRAlon5*R zGBFZqB5fnxue*g)M7D!#K10^4;7oT68+jm!Ia(de@-*2r)HY+pjc9 z?s)F1cc<;{LipWdsbQY%3W1?C&grfdUXv&B*%gajhWTTka)-n%Ec-0_$nw~02_b@) N(H?XC(p|29{0}A!x0(O| literal 0 HcmV?d00001 diff --git a/graphcast/dpm_solver_plus_plus_2s.py b/graphcast/dpm_solver_plus_plus_2s.py new file mode 100644 index 00000000..b3685618 --- /dev/null +++ b/graphcast/dpm_solver_plus_plus_2s.py @@ -0,0 +1,187 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DPM-Solver++ 2S sampler from https://arxiv.org/abs/2211.01095.""" + +from typing import Optional + +from graphcast import casting +from graphcast import denoisers_base +from graphcast import samplers_base as base +from graphcast import samplers_utils as utils +from graphcast import xarray_jax +import haiku as hk +import jax.numpy as jnp +import xarray + + +class Sampler(base.Sampler): + """Sampling using DPM-Solver++ 2S from [1]. + + This is combined with optional stochastic churn as described in [2]. + + The '2S' terminology from [1] means that this is a second-order (2), + single-step (S) solver. Here 'single-step' here distinguishes it from + 'multi-step' methods where the results of function evaluations from previous + steps are reused in computing updates for subsequent steps. The solver still + uses multiple steps though. + + [1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic + Models, https://arxiv.org/abs/2211.01095 + [2] Elucidating the Design Space of Diffusion-Based Generative Models, + https://arxiv.org/abs/2206.00364 + """ + + def __init__(self, + denoiser: denoisers_base.Denoiser, + max_noise_level: float, + min_noise_level: float, + num_noise_levels: int, + rho: float, + stochastic_churn_rate: float, + churn_min_noise_level: float, + churn_max_noise_level: float, + noise_level_inflation_factor: float + ): + """Initializes the sampler. + + Args: + denoiser: A Denoiser which predicts noise-free targets. + max_noise_level: The highest noise level used at the start of the + sequence of reverse diffusion steps. + min_noise_level: The lowest noise level used at the end of the sequence of + reverse diffusion steps. + num_noise_levels: Determines the number of noise levels used and hence the + number of reverse diffusion steps performed. + rho: Parameter affecting the spacing of noise steps. Higher values will + concentrate noise steps more around zero. + stochastic_churn_rate: S_churn from the paper. This controls the rate + at which noise is re-injected/'churned' during the sampling algorithm. + If this is set to zero then we are performing deterministic sampling + as described in Algorithm 1. + churn_min_noise_level: Minimum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + churn_max_noise_level: Maximum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + noise_level_inflation_factor: This can be used to set the actual amount of + noise injected higher than what the denoiser is told has been added. + The motivation is to compensate for a tendency of L2-trained denoisers + to remove slightly too much noise / blur too much. S_noise from the + paper. Only used if stochastic_churn_rate > 0. + """ + super().__init__(denoiser) + self._noise_levels = utils.noise_schedule( + max_noise_level, min_noise_level, num_noise_levels, rho) + self._stochastic_churn = stochastic_churn_rate > 0 + self._per_step_churn_rates = utils.stochastic_churn_rate_schedule( + self._noise_levels, stochastic_churn_rate, churn_min_noise_level, + churn_max_noise_level) + self._noise_level_inflation_factor = noise_level_inflation_factor + + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + + dtype = casting.infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types + noise_levels = jnp.array(self._noise_levels).astype(dtype) + per_step_churn_rates = jnp.array(self._per_step_churn_rates).astype(dtype) + + def denoiser(noise_level: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: + """Computes D(x, sigma, y).""" + bcast_noise_level = xarray_jax.DataArray( + jnp.tile(noise_level, x.sizes['batch']), dims=('batch',)) + # Estimate the expectation of the fully-denoised target x0, conditional on + # inputs/forcings, noisy targets and their noise level: + return self._denoiser( + inputs=inputs, + noisy_targets=x, + noise_levels=bcast_noise_level, + forcings=forcings) + + def body_fn(i: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: + """One iteration of the sampling algorithm. + + Args: + i: Sampling iteration. + x: Noisy targets at iteration i, these will have noise level + self._noise_levels[i]. + + Returns: + Noisy targets at the next lowest noise level self._noise_levels[i+1]. + """ + def init_noise(template): + return noise_levels[0] * utils.spherical_white_noise_like(template) + + # Initialise the inputs if i == 0. + # This is done here to ensure both noise sampler calls can use the same + # spherical harmonic basis functions. While there may be a small compute + # cost the memory savings can be significant. + # TODO(dominicmasters): Figure out if we can merge the two noise sampler + # calls into one to avoid this hack. + maybe_init_noise = (i == 0).astype(noise_levels[0].dtype) + x = x + init_noise(x) * maybe_init_noise + + noise_level = noise_levels[i] + + if self._stochastic_churn: + # We increase the noise level of x a bit before taking it down again: + x, noise_level = utils.apply_stochastic_churn( + x, noise_level, + stochastic_churn_rate=per_step_churn_rates[i], + noise_level_inflation_factor=self._noise_level_inflation_factor) + + # Apply one step of the ODE solver to take x down to the next lowest + # noise level. + + # Note that the Elucidating paper's choice of sigma(t)=t and s(t)=1 + # (corresponding to alpha(t)=1 in the DPM paper) as well as the standard + # choice of r=1/2 (corresponding to a geometric mean for the s_i + # midpoints) greatly simplifies the update from the DPM-Solver++ paper. + # You need to do a bit of algebraic fiddling to arrive at the below after + # substituting these choices into DPMSolver++'s Algorithm 1. The simpler + # update we arrive at helps with intuition too. + + next_noise_level = noise_levels[i + 1] + # This is s_{i+1} from the paper. They don't explain how the s_i are + # chosen, but the default choice seems to be a geometric mean, which is + # equivalent to setting all the r_i = 1/2. + mid_noise_level = jnp.sqrt(noise_level * next_noise_level) + + mid_over_current = mid_noise_level / noise_level + x_denoised = denoiser(noise_level, x) + # This turns out to be a convex combination of current and denoised x, + # which isn't entirely apparent from the paper formulae: + x_mid = mid_over_current * x + (1 - mid_over_current) * x_denoised + + next_over_current = next_noise_level / noise_level + x_mid_denoised = denoiser(mid_noise_level, x_mid) # pytype: disable=wrong-arg-types + x_next = next_over_current * x + (1 - next_over_current) * x_mid_denoised + + # For the final step to noise level 0, we do an Euler update which + # corresponds to just returning the denoiser's prediction directly. + # + # In fact the behaviour above when next_noise_level == 0 is almost + # equivalent, except that it runs the denoiser a second time to denoise + # from noise level 0. The denoiser should just be the identity function in + # this case, but it hasn't necessarily been trained at noise level 0 so + # we avoid relying on this. + return utils.tree_where(next_noise_level == 0, x_denoised, x_next) + + # Init with zeros but apply additional noise at step 0 to initialise the + # state. + noise_init = xarray.zeros_like(targets_template) + return hk.fori_loop( + 0, len(noise_levels) - 1, body_fun=body_fn, init_val=noise_init) diff --git a/graphcast/gencast.py b/graphcast/gencast.py new file mode 100644 index 00000000..1454de76 --- /dev/null +++ b/graphcast/gencast.py @@ -0,0 +1,284 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Denoising diffusion models based on the framework of [1]. + +Throughout we will refer to notation and equations from [1]. + + [1] Elucidating the Design Space of Diffusion-Based Generative Models + Karras, Aittala, Aila and Laine, 2022 + https://arxiv.org/abs/2206.00364 +""" + +from typing import Any, Optional, Tuple + +import chex +from graphcast import casting +from graphcast import denoiser +from graphcast import dpm_solver_plus_plus_2s +from graphcast import graphcast +from graphcast import losses +from graphcast import predictor_base +from graphcast import samplers_utils +from graphcast import xarray_jax +import haiku as hk +import jax +import xarray + + +TARGET_SURFACE_VARS = ( + '2m_temperature', + 'mean_sea_level_pressure', + '10m_v_component_of_wind', + '10m_u_component_of_wind', # GenCast predicts in 12hr timesteps. + 'total_precipitation_12hr', + 'sea_surface_temperature', +) + +TARGET_SURFACE_NO_PRECIP_VARS = ( + '2m_temperature', + 'mean_sea_level_pressure', + '10m_v_component_of_wind', + '10m_u_component_of_wind', + 'sea_surface_temperature', +) + + +TASK = graphcast.TaskConfig( + input_variables=( + # GenCast doesn't take precipitation as input. + TARGET_SURFACE_NO_PRECIP_VARS + + graphcast.TARGET_ATMOSPHERIC_VARS + + graphcast.GENERATED_FORCING_VARS + + graphcast.STATIC_VARS + ), + target_variables=TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS, + # GenCast doesn't take incident solar radiation as a forcing. + forcing_variables=graphcast.GENERATED_FORCING_VARS, + pressure_levels=graphcast.PRESSURE_LEVELS_WEATHERBENCH_13, + # GenCast takes the current frame and the frame 12 hours prior. + input_duration='24h', +) + + +@chex.dataclass(frozen=True, eq=True) +class SamplerConfig: + """Configures the sampler used to draw samples from GenCast. + + max_noise_level: The highest noise level used at the start of the + sequence of reverse diffusion steps. + min_noise_level: The lowest noise level used at the end of the sequence of + reverse diffusion steps. + num_noise_levels: Determines the number of noise levels used and hence the + number of reverse diffusion steps performed. + rho: Parameter affecting the spacing of noise steps. Higher values will + concentrate noise steps more around zero. + stochastic_churn_rate: S_churn from the paper. This controls the rate + at which noise is re-injected/'churned' during the sampling algorithm. + If this is set to zero then we are performing deterministic sampling + as described in Algorithm 1. + churn_max_noise_level: Maximum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + churn_min_noise_level: Minimum noise level at which stochastic churn + occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. + noise_level_inflation_factor: This can be used to set the actual amount of + noise injected higher than what the denoiser is told has been added. + The motivation is to compensate for a tendency of L2-trained denoisers + to remove slightly too much noise / blur too much. S_noise from the + paper. Only used if stochastic_churn_rate > 0. + """ + max_noise_level: float = 80. + min_noise_level: float = 0.03 + num_noise_levels: int = 20 + rho: float = 7. + # Stochastic sampler settings. + stochastic_churn_rate: float = 2.5 + churn_min_noise_level: float = 0.75 + churn_max_noise_level: float = float('inf') + noise_level_inflation_factor: float = 1.05 + + +@chex.dataclass(frozen=True, eq=True) +class NoiseConfig: + training_noise_level_rho: float = 7.0 + training_max_noise_level: float = 88.0 + training_min_noise_level: float = 0.02 + + +@chex.dataclass(frozen=True, eq=True) +class CheckPoint: + description: str + license: str + params: dict[str, Any] + task_config: graphcast.TaskConfig + denoiser_architecture_config: denoiser.DenoiserArchitectureConfig + sampler_config: SamplerConfig + noise_config: NoiseConfig + noise_encoder_config: denoiser.NoiseEncoderConfig + + +class GenCast(predictor_base.Predictor): + """Predictor for a denoising diffusion model following the framework of [1]. + + [1] Elucidating the Design Space of Diffusion-Based Generative Models + Karras, Aittala, Aila and Laine, 2022 + https://arxiv.org/abs/2206.00364 + + Unlike the paper, we have a conditional model and our denoising function + conditions on previous timesteps. + + As the paper demonstrates, the sampling algorithm can be varied independently + of the denoising model and its training procedure, and it is separately + configurable here. + """ + + def __init__( + self, + task_config: graphcast.TaskConfig, + denoiser_architecture_config: denoiser.DenoiserArchitectureConfig, + sampler_config: Optional[SamplerConfig] = None, + noise_config: Optional[NoiseConfig] = None, + noise_encoder_config: Optional[denoiser.NoiseEncoderConfig] = None, + ): + """Constructs GenCast.""" + # Output size depends on number of variables being predicted. + num_surface_vars = len( + set(task_config.target_variables) + - set(graphcast.ALL_ATMOSPHERIC_VARS) + ) + num_atmospheric_vars = len( + set(task_config.target_variables) + & set(graphcast.ALL_ATMOSPHERIC_VARS) + ) + num_outputs = ( + num_surface_vars + + len(task_config.pressure_levels) * num_atmospheric_vars + ) + denoiser_architecture_config.node_output_size = num_outputs + self._denoiser = denoiser.Denoiser( + noise_encoder_config, + denoiser_architecture_config, + ) + self._sampler_config = sampler_config + # Singleton to avoid re-initializing the sampler for each inference call. + self._sampler = None + self._noise_config = noise_config + + def _c_in(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the noisy targets input to the underlying network.""" + return (noise_scale**2 + 1)**-0.5 + + def _c_out(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the underlying network's raw outputs.""" + return noise_scale * (noise_scale**2 + 1)**-0.5 + + def _c_skip(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + """Scaling applied to the skip connection.""" + return 1 / (noise_scale**2 + 1) + + def _loss_weighting(self, noise_scale: xarray.DataArray) -> xarray.DataArray: + r"""The loss weighting \lambda(\sigma) from the paper.""" + return self._c_out(noise_scale) ** -2 + + def _preconditioned_denoiser( + self, + inputs: xarray.Dataset, + noisy_targets: xarray.Dataset, + noise_levels: xarray.DataArray, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """The preconditioned denoising function D from the paper (Eqn 7).""" + raw_predictions = self._denoiser( + inputs=inputs, + noisy_targets=noisy_targets * self._c_in(noise_levels), + noise_levels=noise_levels, + forcings=forcings, + **kwargs) + return (raw_predictions * self._c_out(noise_levels) + + noisy_targets * self._c_skip(noise_levels)) + + def loss_and_predictions( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + ) -> Tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]: + return self.loss(inputs, targets, forcings), self(inputs, targets, forcings) + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + ) -> predictor_base.LossAndDiagnostics: + + if self._noise_config is None: + raise ValueError('Noise config must be specified to train GenCast.') + + # Sample noise levels: + dtype = casting.infer_floating_dtype(targets) # pytype: disable=wrong-arg-types + key = hk.next_rng_key() + batch_size = inputs.sizes['batch'] + noise_levels = xarray_jax.DataArray( + data=samplers_utils.rho_inverse_cdf( + min_value=self._noise_config.training_min_noise_level, + max_value=self._noise_config.training_max_noise_level, + rho=self._noise_config.training_noise_level_rho, + cdf=jax.random.uniform(key, shape=(batch_size,), dtype=dtype)), + dims=('batch',)) + + # Sample noise and apply it to targets: + noise = ( + samplers_utils.spherical_white_noise_like(targets) * noise_levels + ) + noisy_targets = targets + noise + + denoised_predictions = self._preconditioned_denoiser( + inputs, noisy_targets, noise_levels, forcings) + + loss, diagnostics = losses.weighted_mse_per_level( + denoised_predictions, + targets, + # Weights are same as we used for GraphCast. + per_variable_weights={ + # Any variables not specified here are weighted as 1.0. + # A single-level variable, but an important headline variable + # and also one which we have struggled to get good performance + # on at short lead times, so leaving it weighted at 1.0, equal + # to the multi-level variables: + '2m_temperature': 1.0, + # New single-level variables, which we don't weight too highly + # to avoid hurting performance on other variables. + '10m_u_component_of_wind': 0.1, + '10m_v_component_of_wind': 0.1, + 'mean_sea_level_pressure': 0.1, + 'sea_surface_temperature': 0.1, + 'total_precipitation_12hr': 0.1 + }, + ) + loss *= self._loss_weighting(noise_levels) + return loss, diagnostics + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + if self._sampler_config is None: + raise ValueError( + 'Sampler config must be specified to run inference on GenCast.' + ) + if self._sampler is None: + self._sampler = dpm_solver_plus_plus_2s.Sampler( + self._preconditioned_denoiser, **self._sampler_config + ) + return self._sampler(inputs, targets_template, forcings, **kwargs) diff --git a/graphcast/icosahedral_mesh.py b/graphcast/icosahedral_mesh.py index 4c43642b..62837834 100644 --- a/graphcast/icosahedral_mesh.py +++ b/graphcast/icosahedral_mesh.py @@ -279,3 +279,7 @@ def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) return senders, receivers + + +def get_last_triangular_mesh_for_sphere(splits: int) -> TriangularMesh: + return get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)[-1] diff --git a/graphcast/mlp.py b/graphcast/mlp.py new file mode 100644 index 00000000..a60eb661 --- /dev/null +++ b/graphcast/mlp.py @@ -0,0 +1,45 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constructors for MLPs.""" + +import haiku as hk +import jax +import jax.numpy as jnp + + +# TODO(aelkadi): Move the mlp factory here from `deep_typed_graph_net.py`. + + +class LinearNormConditioning(hk.Module): + """Module for norm conditioning. + + Conditions the normalization of "inputs" by applying a linear layer to the + "norm_conditioning" which produces the scale and variance which are applied to + each channel (across the last dim) of "inputs". + """ + + def __init__(self, name="norm_conditioning"): + super().__init__(name=name) + + def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array): + + feature_size = inputs.shape[-1] + conditional_linear_layer = hk.Linear( + output_size=2 * feature_size, + w_init=hk.initializers.TruncatedNormal(stddev=1e-8), + ) + conditional_scale_offset = conditional_linear_layer(norm_conditioning) + scale_minus_one, offset = jnp.split(conditional_scale_offset, 2, axis=-1) + scale = scale_minus_one + 1. + return inputs * scale + offset diff --git a/graphcast/model_utils.py b/graphcast/model_utils.py index 949088c0..11209f11 100644 --- a/graphcast/model_utils.py +++ b/graphcast/model_utils.py @@ -15,6 +15,7 @@ from typing import Mapping, Optional, Tuple +import jax.numpy as jnp import numpy as np from scipy.spatial import transform import xarray @@ -722,3 +723,36 @@ def stacked_to_dataset( name=template_var.name, ) return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count + + +def fourier_features( + values: jnp.ndarray, + base_period: float, + num_frequencies: int, + ) -> jnp.ndarray: + """Maps values to sin/cos features for a range of frequencies. + + Args: + values: Values to compute Fourier features for. + base_period: The base period to use. This should be greater or equal to the + range of the values, or to the period if the values have periodic + semantics (e.g. 2pi if they represent angles). Frequencies used will be + integer multiples of 1/base_period. + num_frequencies: The number of frequencies to use, we will use integer + multiples of 1/base_period from 1 up to num_frequencies inclusive. (We + don't include a zero frequency as this would just give constant features + which are redundant if a bias term is present). + + Returns: + Array with same shape as values except with an extra trailing dimension + of size 2*num_frequencies, which contains a sin and a cos feature for each + frequency. + """ + frequencies = np.arange(1, num_frequencies + 1) / base_period + angular_frequencies = jnp.array(2 * np.pi * frequencies, dtype=values.dtype) + values_times_angular_freqs = values[..., None] * angular_frequencies + return jnp.concatenate( + [jnp.cos(values_times_angular_freqs), + jnp.sin(values_times_angular_freqs)], + axis=-1) + diff --git a/graphcast/nan_cleaning.py b/graphcast/nan_cleaning.py new file mode 100644 index 00000000..ced9e9fc --- /dev/null +++ b/graphcast/nan_cleaning.py @@ -0,0 +1,125 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrappers for Predictors which allow them to work with data cleaned of NaNs. + +The Predictor which is wrapped sees inputs and targets without NaNs, and makes +NaNless predictions. +""" + +from typing import Optional, Tuple + +from graphcast import predictor_base as base +import numpy as np +import xarray + + +class NaNCleaner(base.Predictor): + """A predictor wrapper than removes NaNs from ingested data. + + The Predictor which is wrapped sees inputs and targets without NaNs. + """ + + def __init__( + self, + predictor: base.Predictor, + var_to_clean: str, + fill_value: xarray.Dataset, + reintroduce_nans: bool = False, + ): + """Initializes the NaNCleaner.""" + self._predictor = predictor + self._fill_value = fill_value[var_to_clean] + self._var_to_clean = var_to_clean + self._reintroduce_nans = reintroduce_nans + + def _clean(self, dataset: xarray.Dataset) -> xarray.Dataset: + """Cleans the dataset of NaNs.""" + data_array = dataset[self._var_to_clean] + dataset = dataset.assign( + {self._var_to_clean: data_array.fillna(self._fill_value)} + ) + return dataset + + def _maybe_reintroduce_nans( + self, stale_inputs: xarray.Dataset, predictions: xarray.Dataset + ) -> xarray.Dataset: + # NaN positions don't change between input frames, if they do then + # we should be more careful about re-introducing them. + if self._var_to_clean in predictions.keys(): + nan_mask = np.isnan(stale_inputs[self._var_to_clean]).any(dim='time') + with_nan_values = predictions[self._var_to_clean].where(~nan_mask, np.nan) + predictions = predictions.assign({self._var_to_clean: with_nan_values}) + return predictions + + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> xarray.Dataset: + if self._reintroduce_nans: + # Copy inputs before cleaning so that we can reintroduce NaNs later. + original_inputs = inputs.copy() + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + predictions = self._predictor( + inputs, targets_template, forcings, **kwargs + ) + if self._reintroduce_nans: + predictions = self._maybe_reintroduce_nans(original_inputs, predictions) + return predictions + + def loss( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> base.LossAndDiagnostics: + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if self._var_to_clean in targets.keys(): + targets = self._clean(targets) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + return self._predictor.loss( + inputs, targets, forcings, **kwargs + ) + + def loss_and_predictions( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs, + ) -> Tuple[base.LossAndDiagnostics, xarray.Dataset]: + if self._reintroduce_nans: + # Copy inputs before cleaning so that we can reintroduce NaNs later. + original_inputs = inputs.copy() + if self._var_to_clean in inputs.keys(): + inputs = self._clean(inputs) + if self._var_to_clean in targets.keys(): + targets = self._clean(targets) + if forcings and self._var_to_clean in forcings.keys(): + forcings = self._clean(forcings) + + loss, predictions = self._predictor.loss_and_predictions( + inputs, targets, forcings, **kwargs + ) + if self._reintroduce_nans: + predictions = self._maybe_reintroduce_nans(original_inputs, predictions) + return loss, predictions diff --git a/graphcast/rollout.py b/graphcast/rollout.py index b243d0fe..d10bc1c0 100644 --- a/graphcast/rollout.py +++ b/graphcast/rollout.py @@ -13,11 +13,12 @@ # limitations under the License. """Utils for rolling out models.""" -from typing import Iterator +from typing import Iterator, Optional, Sequence from absl import logging import chex import dask.array +from graphcast import xarray_jax from graphcast import xarray_tree import jax import numpy as np @@ -37,6 +38,170 @@ def __call__( ... +def _replicate_dataset( + data: xarray.Dataset, replica_dim: str, + replicate_to_device: bool, + devices: Sequence[jax.Device], + ) -> xarray.Dataset: + """Used to prepare for xarray_jax.pmap.""" + + def replicate_variable(variable: xarray.Variable) -> xarray.Variable: + if replica_dim in variable.dims: + # TODO(pricei): call device_put_replicated when replicate_to_device==True + return variable.transpose(replica_dim, ...) + else: + data = len(devices) * [variable.data] + if replicate_to_device: + assert devices is not None + # TODO(pricei): Refactor code to use "device_put_replicated" instead of + # device_put_sharded. + data = jax.device_put_sharded(data, devices) + else: + data = np.stack(data, axis=0) + return xarray_jax.Variable( + data=data, dims=(replica_dim,) + variable.dims, attrs=variable.attrs + ) + + def replicate_dataset(dataset: xarray.Dataset) -> xarray.Dataset: + if dataset is None: + return None + data_variables = { + name: replicate_variable(var) + for name, var in dataset.data_vars.variables.items() + } + coords = {name: coord.variable for name, coord in dataset.coords.items()} + return xarray.Dataset(data_variables, coords=coords, attrs=dataset.attrs) + + return replicate_dataset(data) + + +def chunked_prediction_generator_multiple_runs( + predictor_fn: PredictorFn, + rngs: chex.PRNGKey, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset], + num_samples: Optional[int], + pmap_devices: Optional[Sequence[jax.Device]] = None, + **chunked_prediction_kwargs, +) -> Iterator[xarray.Dataset]: + """Outputs a trajectory of multiple samples by yielding chunked predictions. + + Args: + predictor_fn: Function to use to make predictions for each chunk. + rngs: RNG sequence to be used for each ensemble member. + inputs: Inputs for the model. + targets_template: Template for the target prediction, requires targets + equispaced in time. + forcings: Optional forcing for the model. + num_samples: The number of runs / samples to rollout. + pmap_devices: List of devices over which predictor_fn is pmapped, or None if + it is not pmapped. + **chunked_prediction_kwargs: + See chunked_prediction, some of these are required arguments. + + Yields: + The predictions for each chunked step of the chunked rollout, such that + if all predictions are concatenated in time and sample dimension squeezed, + this would match the targets template in structure. + + """ + if pmap_devices is not None: + assert ( + num_samples % jax.device_count() == 0 + ), "num_samples must be a multiple of jax.device_count()" + + def predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings): + targets_template = _replicate_dataset( + targets_template, + replica_dim="sample", + replicate_to_device=True, + devices=pmap_devices, + ) + return predictor_fn(rng, inputs, targets_template, forcings) + + for i in range(0, num_samples, jax.device_count()): + sample_idx = slice(i, i + jax.device_count()) + logging.info("Samples %s out of %s", sample_idx, num_samples) + logging.flush() + sample_group_rngs = rngs[sample_idx] + + if "sample" not in inputs.dims: + sample_inputs = inputs + else: + sample_inputs = inputs.isel(sample=sample_idx, drop=True) + + sample_inputs = _replicate_dataset( + sample_inputs, + replica_dim="sample", + replicate_to_device=True, + devices=pmap_devices, + ) + + if forcings is not None: + if "sample" not in forcings.dims: + sample_forcings = forcings + else: + sample_forcings = forcings.isel(sample=sample_idx, drop=True) + + # TODO(pricei): We are replicating the full forcings for all rollout + # timesteps here, rather than inside `predictor_fn_pmap_named_args` like + # the targets_template above, because the forcings are concatenated with + # the inputs which will already be replicated. We should refactor this + # so that chunked prediction is aware of whether it is being run with + # pmap, and if so do the replication and device_put only of the + # necessary timesteps, as part of the chunked prediction function. + sample_forcings = _replicate_dataset( + sample_forcings, + replica_dim="sample", + replicate_to_device=False, + devices=pmap_devices, + ) + else: + sample_forcings = None + + for prediction_chunk in chunked_prediction_generator( + predictor_fn=predictor_fn_pmap_named_args, + rng=sample_group_rngs, + inputs=sample_inputs, + targets_template=targets_template, + forcings=sample_forcings, + pmap_devices=pmap_devices, + **chunked_prediction_kwargs, + ): + prediction_chunk.coords["sample"] = np.arange( + sample_idx.start, sample_idx.stop, sample_idx.step + ) + yield prediction_chunk + del prediction_chunk + else: + for i in range(num_samples): + logging.info("Sample %d/%d", i, num_samples) + logging.flush() + this_sample_rng = rngs[i] + + if "sample" in inputs.dims: + sample_inputs = inputs.isel(sample=i, drop=True) + else: + sample_inputs = inputs + + sample_forcings = forcings + if sample_forcings is not None: + if "sample" in sample_forcings.dims: + sample_forcings = sample_forcings.isel(sample=i, drop=True) + + for prediction_chunk in chunked_prediction_generator( + predictor_fn=predictor_fn, + rng=this_sample_rng, + inputs=sample_inputs, + targets_template=targets_template, + forcings=sample_forcings, + **chunked_prediction_kwargs): + prediction_chunk.coords["sample"] = i + yield prediction_chunk + del prediction_chunk + + def chunked_prediction( predictor_fn: PredictorFn, rng: chex.PRNGKey, @@ -85,6 +250,7 @@ def chunked_prediction_generator( forcings: xarray.Dataset, num_steps_per_chunk: int = 1, verbose: bool = False, + pmap_devices: Optional[Sequence[jax.Device]] = None ) -> Iterator[xarray.Dataset]: """Outputs a long trajectory by yielding chunked predictions. @@ -99,6 +265,8 @@ def chunked_prediction_generator( at each call of `predictor_fn`. It must evenly divide the number of steps in `targets_template`. verbose: Whether to log the current chunk being predicted. + pmap_devices: List of devices over which predictor_fn is pmapped, or None if + it is not pmapped. Yields: The predictions for each chunked step of the chunked rollout, such as @@ -140,6 +308,19 @@ def chunked_prediction_generator( time=slice(0, num_steps_per_chunk)) current_inputs = inputs + + def split_rng_fn(rng): + # Note, this is *not* equivalent to `return jax.random.split(rng)`, because + # by assigning to a tuple, the single numpy array returned by + # `jax.random.split` actually gets split into two arrays, so when calling + # the function with pmap the output is Tuple[Array, Array], where the + # leading axis of each array is `num devices`. + rng1, rng2 = jax.random.split(rng) + return rng1, rng2 + + if pmap_devices is not None: + split_rng_fn = jax.pmap(split_rng_fn, devices=pmap_devices) + for chunk_index in range(num_chunks): if verbose: logging.info("Chunk %d/%d", chunk_index, num_chunks) @@ -160,13 +341,24 @@ def chunked_prediction_generator( current_forcings = current_forcings.assign_coords(time=targets_chunk_time) current_forcings = current_forcings.compute() # Make predictions for the chunk. - rng, this_rng = jax.random.split(rng) + rng, this_rng = split_rng_fn(rng) predictions = predictor_fn( rng=this_rng, inputs=current_inputs, targets_template=current_targets_template, forcings=current_forcings) + # In the pmapped case, profiling reveals that the predictions, forcings and + # inputs are all copied onto a single TPU, causing OOM. To avoid this + # we pull all of the input/output data off the devices. This will have + # some performance impact, but maximise the memory efficiency. + # TODO(aelkadi): Pmap `_get_next_inputs` when running under pmap, and + # remove the device_get. + if pmap_devices is not None: + predictions = jax.device_get(predictions) + current_forcings = jax.device_get(current_forcings) + current_inputs = jax.device_get(current_inputs) + next_frame = xarray.merge([predictions, current_forcings]) next_inputs = _get_next_inputs(current_inputs, next_frame) diff --git a/graphcast/samplers_base.py b/graphcast/samplers_base.py new file mode 100644 index 00000000..6d60aa83 --- /dev/null +++ b/graphcast/samplers_base.py @@ -0,0 +1,47 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class for diffusion samplers.""" + +import abc +from typing import Optional + +from graphcast import denoisers_base +import xarray + + +class Sampler(abc.ABC): + """A sampling algorithm for a denoising diffusion model. + + This is constructed with a denoising function, and uses it to draw samples. + """ + + _denoiser: denoisers_base.Denoiser + + def __init__(self, denoiser: denoisers_base.Denoiser): + """Constructs Sampler. + + Args: + denoiser: A Denoiser which has been trained with an MSE loss to predict + the noise-free targets. + """ + self._denoiser = denoiser + + @abc.abstractmethod + def __call__( + self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: Optional[xarray.Dataset] = None, + **kwargs) -> xarray.Dataset: + """Draws a sample using self._denoiser. Contract like Predictor.__call__.""" diff --git a/graphcast/samplers_utils.py b/graphcast/samplers_utils.py new file mode 100644 index 00000000..7b0125d1 --- /dev/null +++ b/graphcast/samplers_utils.py @@ -0,0 +1,431 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for diffusion samplers. Makes use of dinosaur.spherical_harmonic.""" + +import dataclasses +import functools +from typing import Any, cast, Optional, Tuple + +import chex +from dinosaur import spherical_harmonic +from graphcast import xarray_jax +from graphcast import xarray_tree +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import xarray + +# Some useful constants useful when dealing with Earth's geometry. +# The earth isn't really a sphere so these are only approximate, this is the +# average radius according to https://en.wikipedia.org/wiki/Earth_radius, +# with the actual value varying from 6378 to 6357km. +EARTH_RADIUS_KM = 6371. +# And this is also approximate, but we've chosen to make it consistent with the +# radius above when modelling the earth as a sphere. This gives a value of +# around 40030; the actual value varies from 40008 to 40075. +EARTH_CIRCUMFERENCE_KM = EARTH_RADIUS_KM * 2 * np.pi + + +@dataclasses.dataclass(frozen=True) +class _ArrayGrid: + """A class that performs operations and transformations in the spectral basis. + + Attributes: + longitude_wavenumbers: num of longitude wavenumbers in the spectral basis. + total_wavenumbers: number of total wavenumbers in the spectral basis. + longitude_nodes: number of quadrature nodes along the lon direction. + latitude_nodes: number of quadrature nodes along the lat direction. + latitude_spacing: either 'gauss' or 'equiangular'. This determines the + spacing of nodal grid points in the latitudinal (north-south) direction. + """ + longitude_wavenumbers: int + total_wavenumbers: int + longitude_nodes: int + latitude_nodes: int + latitude_spacing: str + + @classmethod + def with_lat_lon( + cls, + lat: np.ndarray, + lon: np.ndarray, + ) -> '_ArrayGrid': + """_ArrayGrid for use with data in specified lat/lon grid (in degrees).""" + + latitude_nodes = lat.shape[0] + longitude_nodes = lon.shape[0] + latitude_spacing = _infer_latitude_spacing(lat) + if latitude_spacing in ['equiangular', 'gauss']: + if longitude_nodes != 2 * latitude_nodes: + # Technically not a requirement but useful to ensure `max_wavenumber` + # makes sense. + raise ValueError( + 'Unexpected number of longitude nodes. ' + f'Expected {2 * latitude_nodes}, got {longitude_nodes}') + elif latitude_spacing == 'equiangular_with_poles': + if longitude_nodes != 2 * (latitude_nodes - 1): + # Technically not a requirement but useful to ensure `max_wavenumber` + # makes s + raise ValueError( + 'Unexpected number of longitude nodes. ' + f'Expected {2 * (latitude_nodes - 1)}, got {longitude_nodes}') + else: + raise ValueError(f'Unexpected latitude_spacing={latitude_spacing}') + max_wavenumber = int(longitude_nodes // 2) - 1 + grid = cls( + longitude_wavenumbers=max_wavenumber+1, + # total_wavenumbers should be one larger than max_wavenumber as the + # wavenumbers go from 0 to max_wavenumber inclusive. + total_wavenumbers=max_wavenumber+1, + longitude_nodes=longitude_nodes, + latitude_nodes=latitude_nodes, + latitude_spacing=latitude_spacing, + ) + _verify_nodal_axes(lat, lon, grid.nodal_axes) + return grid + + @functools.cached_property + def _grid(self) -> spherical_harmonic.Grid: + return spherical_harmonic.Grid( + spherical_harmonics_impl=spherical_harmonic.RealSphericalHarmonics, + **dataclasses.asdict(self), + ) + + @functools.cached_property + def nodal_axes(self) -> Tuple[np.ndarray, np.ndarray]: + """Longitude and sin(latitude) coordinates of the nodal basis.""" + return self._grid.nodal_axes + + @functools.cached_property + def modal_axes(self) -> Tuple[np.ndarray, np.ndarray]: + """Longitudinal and total wavenumbers (m, l) of the modal basis.""" + return self._grid.modal_axes + + def to_nodal(self, x: chex.Array) -> chex.Array: + """Maps `x` from a modal to nodal representation.""" + return self._grid.to_nodal(x) + + +def _infer_latitude_spacing(lat: np.ndarray) -> str: + """Infers the type of latitude spacing given the latitude.""" + if not np.all(np.diff(lat) > 0.): + raise ValueError('Latitude values are expected to be sorted.') + + if np.allclose(np.diff(lat), lat[1] - lat[0]): + if np.isclose(max(lat), 90.): + spacing = 'equiangular_with_poles' + else: + spacing = 'equiangular' + else: + spacing = 'gauss' + return spacing + + +def _verify_nodal_axes(lat_coords: np.ndarray, lon_coords: np.ndarray, + nodal_axes: Tuple[np.ndarray, np.ndarray]): + nodal_axes_lon, nodal_axes_sin_lat = nodal_axes + if not np.allclose(nodal_axes_sin_lat, np.sin(np.deg2rad(lat_coords))): + raise ValueError( + "Latitude coords don't match those used by " + "spherical_harmonic.SphericalHarmonicBasis.") + if not np.allclose(nodal_axes_lon, np.deg2rad(lon_coords)): + raise ValueError( + "Longitude coords don't match those used by " + "spherical_harmonic.SphericalHarmonicBasis.") + + +class Grid: + """xarray wrapper around _ArrayGrid.""" + + @classmethod + def for_nodal_data( + cls, + nodal_data: xarray.DataArray, + ) -> 'Grid': + """A Grid for use with a given shape of nodal (lat/lon grid) data. + + This uses the maximum number of spherical harmonics that the grid is able + to resolve. + + This class supports data arrays with latitude spacings as defined by + "dinosaur.spherical_harmonic". In summary: + * 'equiangular': equally spaced (by `d_lat`) values between -90 + d_lat and + 90 - d_lat / 2. In our case, longitude must also be spaced by `d_lat`. + * 'equiangular_with_poles': equally spaced (by `d_lat`) values between -90 + and 90. In our case, longitude must also be spaced by `d_lat`. + * 'gauss': Gauss-Legendre nodes. + + Args: + nodal_data: An xarray with 'lat' and 'lon' dimensions and coordinates in + degrees. + + Returns: + A grid with the specified latitude_nodes, with + longitude_nodes=2*latitude_nodes and max_wavenumber=latitude_nodes-1. + """ + + grid = _ArrayGrid.with_lat_lon( + nodal_data.coords['lat'].data, + nodal_data.coords['lon'].data) + return cls(grid, + nodal_data.coords['lat'].data, + nodal_data.coords['lon'].data) + + def __init__(self, + grid: _ArrayGrid, + lat_coords: np.ndarray, + lon_coords: np.ndarray): + _verify_nodal_axes(lat_coords, lon_coords, grid.nodal_axes) + self._underlying = grid + # Record the exact original lat/lon coords so we can return them exactly + # from an inverse transform, avoiding any xarray merge issues if coordinates + # are off by a rounding error. + self._lat_coords = lat_coords + self._lon_coords = lon_coords + self._longitude_wavenumber_coords, self._total_wavenumber_coords = ( + grid.modal_axes) + + @property + def total_wavenumber_coords(self) -> xarray.DataArray: + """Coords that must be used for 'total_wavenumber' dimension.""" + return xarray.DataArray( + data=self._total_wavenumber_coords, + dims=('total_wavenumber',), + coords={'total_wavenumber': self._total_wavenumber_coords}) + + @property + def longitude_wavenumber_coords(self) -> xarray.DataArray: + """Coords that must be used for 'longitude_wavenumber' dimension.""" + return xarray.DataArray( + data=self._longitude_wavenumber_coords, + dims=('longitude_wavenumber',), + coords={'longitude_wavenumber': self._longitude_wavenumber_coords}) + + def to_nodal( + self, modal_data: xarray.DataArray) -> xarray.DataArray: + """Applies the inverse spherical harmonic transform. + + Args: + modal_data: A tree of xarray.DataArray with 'longitude_wavenumber' and + 'total_wavenumber' dimensions with coords + `self.longitude_wavenumber_coords` and `self.total_wavenumber_coords` + respectively, and with the same sparsity pattern described under + `to_modal`. + + Returns: + Corresponding tree where the 'longitude_wavenumber' and + 'total_wavenumber' dimensions are replaced by 'lat', 'lon' dimensions. + """ + def inverse_transform(modal: xarray.DataArray) -> xarray.DataArray: + if (not np.all(modal.coords['longitude_wavenumber'] == + self._longitude_wavenumber_coords) or + not np.all(modal.coords['total_wavenumber'] == + self._total_wavenumber_coords)): + raise ValueError('Wavenumber coords don\'t follow required convention.') + + return xarray_jax.apply_ufunc( + self._underlying.to_nodal, modal, + input_core_dims=[['longitude_wavenumber', 'total_wavenumber']], + output_core_dims=[['lon', 'lat']], + ).assign_coords( + lon=self._lon_coords, + lat=self._lat_coords, + ) + + return xarray_tree.map_structure(inverse_transform, modal_data) + + +def sample( + key: jnp.ndarray, + power_spectrum: xarray.DataArray, + template: xarray.DataArray, + grid: Optional[Grid] = None, + ) -> xarray.DataArray: + """Samples Gaussian Process noise on a sphere, with a given power spectrum. + + This means the noise will have the given power spectrum *in expectation*; the + power spectrum of individual samples may vary. + + The noise will be isotropic, meaning the distribution is invariant to + rotations of the sphere. + + The marginal variance of the returned values will be equal to the total power, + i.e. the sum of power_spectrum. So if you want unit marginal variance, just + make sure to normalize the power_spectrum to sum to 1. + + Args: + key: JAX rng key. + power_spectrum: An array with shape (total_wavenumber,) giving the power + which is desired at each total wavenumber (corresponding to a wavelength + EARTH_CIRCUMFERENCE/total_wavenumber) for total wavenumbers 0 up to some + maximum. This is in squared units of the quantity being sampled. + template: An array with the shape that you want the samples in, containing + 'lat' and 'lon' dimensions. If other dimensions are present, we draw + multiple independent samples along these other dimensions. + grid: spherical_harmonic.Grid on which to sample the noise. If not specified + a grid will be created based on `template`, however note you may save some + RAM and compute by re-using a single Grid instance across multiple calls. + + Returns: + DataArray with the same shape as template. + """ + if grid is None: + grid = Grid.for_nodal_data(template) + dims = [d for d in template.dims if d not in ('lat', 'lon')] + shape = [template.sizes[d] for d in dims] + coords = {name: coord for name, coord in template.coords.items() + if name not in ('lat', 'lon')} + dims.extend(('total_wavenumber', 'longitude_wavenumber')) + shape.extend((len(grid.total_wavenumber_coords), + len(grid.longitude_wavenumber_coords))) + coords.update({'total_wavenumber': grid.total_wavenumber_coords, + 'longitude_wavenumber': grid.longitude_wavenumber_coords}) + coeffs = xarray_jax.DataArray( + data=jax.random.normal(key, shape), dims=dims, coords=coords) + # Mask out coefficients which are out of range. This broadcasts to a + # triangular mask with shape (total_wavenumber, longitude_wavenumber): + mask = ( + abs(coeffs.longitude_wavenumber) <= coeffs.total_wavenumber + ).astype(np.float32) + # For total_wavenumber t, there will be 2t+1 non-zero coefficients at + # different longitude_wavenumbers. We must normalize the coefficients so that + # summing their squares at each total_wavenumber, sums to the corresponding + # value in the power spectrum: + multiplier = mask * np.sqrt(power_spectrum / mask.sum( + 'longitude_wavenumber', skipna=False)) + # And a standard normalization factor used in this implementation of the + # spherical harmonic transform: + multiplier *= np.sqrt(4 * np.pi) + # Only finally multiply by coeffs to avoid too many broadcasting + # multiplications: + coeffs *= multiplier + result = cast(xarray.DataArray, grid.to_nodal(coeffs)) + result = result.astype(template.dtype) + return result.transpose(*template.dims) + + +def spherical_white_noise_like(template: xarray.Dataset) -> xarray.Dataset: + """Samples isotropic mean 0 variance 1 white noise on the sphere.""" + def spherical_white_noise_like_dataarray(data_array: xarray.DataArray + ) -> xarray.DataArray: + num_wavenumbers = data_array.lon.shape[0] // 2 + key = hk.next_rng_key() + return sample( + key=key, + power_spectrum=xarray_jax.DataArray( + data=np.array([1/num_wavenumbers for _ in range(num_wavenumbers)]), + dims=['total_wavenumber']), + template=data_array) + return template.map(spherical_white_noise_like_dataarray) + + +def rho_inverse_cdf( + min_value: float, + max_value: float, + rho: float, + cdf: Any) -> Any: + """Quantiles of rho distribution used for noise levels at sampling time. + + This is parameterised by rho as in Eqn 5 from the Elucidating paper + (but with max/min flipped so that quantiles are given in ascending not + descending order). It's equivalent to a Beta[rho, 1] distribution rescaled to + [min_value, max_value]. + + At sampling time we use noise levels at fixed quantiles of this distribution. + Unlike in the paper, we also use the same distribution for noise levels at + training time (albeit potentially with different parameters, and sampling from + it at random). + + Args: + min_value: + max_value: + Define the support of the distribution. + rho: + Shape parameter. + cdf: + Value or values between 0 and 1 indicating which quantile you want. Can + be a numpy or jax array. + + Returns: + Quantiles of the distribution, with same shape/type as `cdf`. + """ + return ( + min_value**(1 / rho) + cdf * + (max_value**(1 / rho) - min_value**(1 / rho)) + )**rho + + +def tree_where( + cond: jnp.ndarray, + xs: Any, + ys: Any + ) -> Any: + """Like jnp.where but works with trees for xs and ys (but not for cond).""" + return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), xs, ys) + + +def noise_schedule( + max_noise_level: float = 80., + min_noise_level: float = 0.002, + num_noise_levels: int = 30, + rho: float = 7., +) -> np.ndarray: + """Computes a descending noise schedule for sampling, ending with zero.""" + noise_levels = rho_inverse_cdf( + min_value=min_noise_level, + max_value=max_noise_level, + rho=rho, + # We want the noise levels in descending order, so ask for quantiles + # 1 down to 0: + cdf=np.linspace(1, 0, num_noise_levels)) + # The final zero noise level is somewhat special-cased. We don't actually + # denoise from this noise level but appending it here is convenient for + # sampling loop implementations. + return np.append(noise_levels, 0.) + + +def stochastic_churn_rate_schedule( + noise_levels: np.ndarray, + stochastic_churn_rate: float = 0., + churn_min_noise_level: float = 0.05, + churn_max_noise_level: float = 50.0, +) -> np.ndarray: + """Computes a stochastic churn rate for each noise level.""" + num_noise_levels = len(noise_levels)-1 # Exclude final zero noise level. + # As in the Elucidated Diffusion paper, clamp this so it doesn't increase the + # variance by a factor of more than 2, no matter how few noise levels are + # used: + per_step_churn_rate = min(stochastic_churn_rate / num_noise_levels, + np.sqrt(2) - 1) + return ( + (churn_min_noise_level <= noise_levels[:-1]) & + (noise_levels[:-1] <= churn_max_noise_level) + ) * per_step_churn_rate + + +def apply_stochastic_churn( + x: Any, + noise_level: jax.typing.ArrayLike, + stochastic_churn_rate: jax.typing.ArrayLike, + noise_level_inflation_factor: jax.typing.ArrayLike, +) -> tuple[Any, jax.typing.ArrayLike]: + """Returns x at higher noise level, and the higher noise level itself.""" + # We increase the noise level of x a bit before taking it down again: + new_noise_level = noise_level * (1.0 + stochastic_churn_rate) + extra_noise_stddev = (jnp.sqrt(new_noise_level**2 - noise_level**2) + * noise_level_inflation_factor) + updated_x = x + spherical_white_noise_like(x) * extra_noise_stddev + return updated_x, new_noise_level + diff --git a/graphcast/sparse_transformer.py b/graphcast/sparse_transformer.py new file mode 100644 index 00000000..502bd840 --- /dev/null +++ b/graphcast/sparse_transformer.py @@ -0,0 +1,577 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformer with either dense or sparse attention. + +The sparse attention implemented here is for nodes to attend only to themselves +and their neighbours on the graph). It assumes that the adjacency matrix has a +banded structure, and is implemented with dense operations computing with only +the diagonal, super diagonal, and subdiagonal blocks of the tri-block-diagonal +attention matrix. + +The basic model structure of the transformer and some functions were adapted +from xlm's transformer_simple.py. +""" + +import dataclasses +import logging +from typing import Any, Callable, Literal, Optional, Tuple + +from graphcast import mlp as mlp_builder +from graphcast import sparse_transformer_utils as utils +import haiku as hk +import jax +from jax.experimental.pallas.ops.tpu import splash_attention +import jax.numpy as jnp +import numpy as np +import scipy as sp + + +@dataclasses.dataclass +class _ModelConfig: + """Transformer config.""" + # Depth, or num transformer blocks. One 'layer' is attn + ffw. + num_layers: int + # Primary width, the number of channels on the carrier path. + d_model: int + # Number of heads for self-attention. + num_heads: int + # Mask block size. + mask_block_size: int + # Attention type - 'mha' or 'triblockdiag_mha' + attention_type: str = 'triblockdiag_mha' + block_q: Optional[int] = None + block_kv: Optional[int] = None + block_kv_compute: Optional[int] = None + block_q_dkv: Optional[int] = None + block_kv_dkv: Optional[int] = None + block_kv_dkv_compute: Optional[int] = None + # mask type if splash attention being used - 'full' or 'lazy' + mask_type: Optional[str] = 'full' + # Number of channels per-head for self-attn QK computation. + key_size: Optional[int] = None + # Number of channels per-head for self-attn V computation. + value_size: Optional[int] = None + # Activation to use, any in jax.nn. + activation: str = 'gelu' + # Init scale for ffw layers (divided by num_layers) + ffw_winit_mult: float = 2.0 + # Init scale for final ffw layer (divided by depth) + ffw_winit_final_mult: float = 2.0 + # Init scale for mha proj (divided by depth). + attn_winit_mult: float = 2.0 + # Init scale for mha w (divided by depth). + attn_winit_final_mult: float = 2.0 + # Number of hidden units in the MLP blocks. Defaults to 4 * d_model. + ffw_hidden: Optional[int] = None + + def __post_init__(self): + if self.ffw_hidden is None: + self.ffw_hidden = 4 * self.d_model + # Compute key_size and value_size from d_model // num_heads. + if self.key_size is None: + if self.d_model % self.num_heads != 0: + raise ValueError('num_heads has to divide d_model exactly') + self.key_size = self.d_model // self.num_heads + if self.value_size is None: + if self.d_model % self.num_heads != 0: + raise ValueError('num_heads has to divide d_model exactly') + self.value_size = self.d_model // self.num_heads + + +def get_mask_block_size(mask: sp.sparse.csr_matrix) -> int: + """Get blocksize of the adjacency matrix (attn mask) for the permuted mesh.""" + # sub-diagonal bandwidth + lbandwidth = ( + np.arange(mask.shape[0]) - (mask != 0).argmax(axis=0) + 1).max() + # super-diagonal bandwidth + ubandwidth = ( + (mask.shape[0]-1) - np.argmax(mask[::-1,:] != 0, axis=0 + ) - np.arange(mask.shape[0]) + 1).max() + block_size = np.maximum(lbandwidth, ubandwidth) + return block_size + + +def ffw(x: jnp.ndarray, cfg: _ModelConfig) -> jnp.ndarray: + """Feed-forward block.""" + ffw_winit = hk.initializers.VarianceScaling(cfg.ffw_winit_mult / + cfg.num_layers) + ffw_winit_final = hk.initializers.VarianceScaling(cfg.ffw_winit_final_mult / + cfg.num_layers) + x = hk.Linear(cfg.ffw_hidden, name='ffw_up', w_init=ffw_winit)(x) + x = getattr(jax.nn, cfg.activation)(x) + return hk.Linear(cfg.d_model, name='ffw_down', w_init=ffw_winit_final)(x) + + +def triblockdiag_softmax(logits: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Softmax given the diag, upper diag, and lower diag logit blocks.""" + + logits_d, logits_u, logits_l = logits + + m = jnp.max(jnp.stack([ + jax.lax.stop_gradient(logits_d.max(-1, keepdims=True)), + jax.lax.stop_gradient(logits_u.max(-1, keepdims=True)), + jax.lax.stop_gradient(logits_l.max(-1, keepdims=True))]), axis=0) + + unnormalized_d = jnp.exp(logits_d - m) + unnormalized_u = jnp.exp(logits_u - m) + unnormalized_l = jnp.exp(logits_l - m) + + denom = ( + unnormalized_d.sum(-1, keepdims=True) + + unnormalized_u.sum(-1, keepdims=True) + + unnormalized_l.sum(-1, keepdims=True) + ) + + logits_d = unnormalized_d / denom + logits_u = unnormalized_u / denom + logits_l = unnormalized_l / denom + + return (logits_d, logits_u, logits_l) + + +def triblockdiag_mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray, cfg: _ModelConfig, + ) -> jnp.ndarray: + """Triblockdiag multihead attention.""" + + # q_inputs, kv_input: (batch, num_blocks, block_size, num_heads, d_model) + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + k = jnp.pad(k, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0))) + v = jnp.pad(v, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0))) + + def qk_prod(queries, keys): + return jnp.einsum('bnqhd,bnkhd->bnhqk', queries, keys) + + # q shape is (batch, num_blocks, block_size, num_heads, qk_dim) + # k shape is (batch, num_blocks + 2, block_size, num_heads, qk_dim) + logits_d = qk_prod(q, k[:, 1:-1, ...]) * cfg.key_size**-0.5 + logits_u = qk_prod(q, k[:, 2:, ...]) * cfg.key_size**-0.5 + logits_l = qk_prod(q, k[:, :-2, ...]) * cfg.key_size**-0.5 + + # apply mask + logits_d = jnp.where(mask[:, 0, ...], logits_d, -1e30) + logits_u = jnp.where(mask[:, 1, ...], logits_u, -1e30) + logits_l = jnp.where(mask[:, 2, ...], logits_l, -1e30) + + logits_d, logits_u, logits_l = utils.wrap_fn_for_upcast_downcast( + (logits_d, logits_u, logits_l), + triblockdiag_softmax + ) + + def av_prod(attn_weights, values): + return jnp.einsum('bnhqk,bnkhd->bnqhd', attn_weights, values) + + out_d = av_prod(logits_d, v[:, 1:-1, ...]) + out_u = av_prod(logits_u, v[:, 2:, ...]) + out_l = av_prod(logits_l, v[:, :-2, ...]) + # x shape is (batch, num_blocks, block_size, num_heads, d_model) + x = out_d + out_u + out_l + + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def multihead_linear( + x: jnp.ndarray, qkv: str, cfg: _ModelConfig +) -> jnp.ndarray: + """Linearly project `x` to have `head_size` dimensions per head.""" + head_size = cfg.value_size if qkv == 'v' else cfg.key_size + attn_winit = hk.initializers.VarianceScaling(cfg.attn_winit_mult / + cfg.num_layers) + out = hk.Linear( + cfg.num_heads * head_size, + w_init=attn_winit, + name='mha_proj_' + qkv, + with_bias=False, + )(x) + shape = out.shape[:-1] + (cfg.num_heads, head_size) + return jnp.reshape(out, shape) + + +def mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray, cfg: _ModelConfig, + normalize_logits: bool = True, + ) -> jnp.ndarray: + """Multi head attention.""" + + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + logits = jnp.einsum('bthd, bThd->bhtT', q, k) + if normalize_logits: + logits *= cfg.key_size**-0.5 + if mask is not None: + def apply_mask(m, l): + return jnp.where(m, l, -1e30) + logits = jax.vmap(jax.vmap( + apply_mask, in_axes=[None, 0]), in_axes=[None, 0])(mask, logits) + + # Wrap softmax weights for upcasting & downcasting in case of BF16 activations + weights = utils.wrap_fn_for_upcast_downcast(logits, jax.nn.softmax) + + # Note: our mask never has all 0 rows, since nodes always have self edges, + # so no need to account for that possibility explicitly. + + x = jnp.einsum('bhtT,bThd->bthd', weights, v) + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def _make_splash_mha( + mask, + mask_type: str, + num_heads: int, + block_q: Optional[int] = None, + block_kv: Optional[int] = None, + block_kv_compute: Optional[int] = None, + block_q_dkv: Optional[int] = None, + block_kv_dkv: Optional[int] = None, + block_kv_dkv_compute: Optional[int] = None, + tanh_soft_cap: Optional[float] = None, +) -> Callable[..., jnp.ndarray]: + """Construct attention kernel.""" + if mask_type == 'full': + mask = np.broadcast_to(mask[None], + (num_heads, *mask.shape)).astype(np.bool_) + + block_sizes = splash_attention.BlockSizes( + block_q=block_q, + block_kv=block_kv, + block_kv_compute=block_kv_compute, + block_q_dkv=block_q_dkv, + block_kv_dkv=block_kv_dkv, + block_kv_dkv_compute=block_kv_dkv_compute, + use_fused_bwd_kernel=True, + ) + attn = splash_attention.make_splash_mha(mask, block_sizes=block_sizes, + head_shards=1, + q_seq_shards=1, + attn_logits_soft_cap=tanh_soft_cap, + ) + return attn + + +def splash_mha(q_input: jnp.ndarray, kv_input: jnp.ndarray, + mask: jnp.ndarray | splash_attention.splash_attention_mask.Mask, + cfg: _ModelConfig, + tanh_soft_cap: Optional[float] = None, + normalize_q: bool = True) -> jnp.ndarray: + """Splash attention.""" + + q = multihead_linear(q_input, 'q', cfg) + k = multihead_linear(kv_input, 'k', cfg) + v = multihead_linear(kv_input, 'v', cfg) + + _, _, num_heads, head_dim = q.shape + + assert head_dim % 128 == 0 # splash attention kernel requires this + + attn = _make_splash_mha( + mask=mask, + mask_type=cfg.mask_type, + num_heads=num_heads, + block_q=cfg.block_q, + block_kv=cfg.block_kv, + block_kv_compute=cfg.block_kv_compute, + block_q_dkv=cfg.block_q_dkv, + block_kv_dkv=cfg.block_kv_dkv, + block_kv_dkv_compute=cfg.block_kv_dkv_compute, + tanh_soft_cap=tanh_soft_cap, + ) + attn = jax.vmap(attn) # Add batch axis. + + if normalize_q: + q *= cfg.key_size**-0.5 + + # (batch, nodes, num_heads, head_dim) -> (batch, num_heads, nodes, head_dim) + reformat = lambda y: y.transpose(0, 2, 1, 3) + x = attn(q=reformat(q), k=reformat(k), v=reformat(v)) + x = x.transpose(0, 2, 1, 3) + + x = jnp.reshape(x, x.shape[:-2] + (cfg.num_heads * cfg.value_size,)) + + attn_winit_final = hk.initializers.VarianceScaling( + cfg.attn_winit_final_mult / cfg.num_layers) + + x = hk.Linear(cfg.d_model, name='mha_final', w_init=attn_winit_final)(x) + return x + + +def layernorm( + x: jnp.ndarray, create_scale: bool, create_offset: bool +) -> jnp.ndarray: + return hk.LayerNorm( + axis=-1, create_scale=create_scale, create_offset=create_offset, + name='norm')(x) + + +def mask_block_diags(mask: sp.sparse.csr_matrix, + num_padding_nodes: int, + block_size: int) -> jnp.ndarray: + """Pad and reshape mask diag, super-siag and sub-diag blocks.""" + # add zero padding to mask + mask_padding_rows = sp.sparse.csr_matrix( + (num_padding_nodes, mask.shape[1]), dtype=jnp.int32) + mask = sp.sparse.vstack([mask, mask_padding_rows]) + mask_padding_cols = sp.sparse.csr_matrix( + (mask.shape[0], num_padding_nodes), dtype=jnp.int32) + mask = sp.sparse.hstack([mask, mask_padding_cols]) + + assert (mask.shape[-1] % block_size) == 0 + mask_daig_blocks = jnp.stack( + [jnp.array(mask[i * block_size : (i + 1) * block_size, + i * block_size : (i + 1) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size)]) + mask_upper_diag_blocks = jnp.stack( + [jnp.array(mask[i * block_size : (i + 1) * block_size, + (i + 1) * block_size : (i + 2) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size - 1)] + + [jnp.zeros((block_size, block_size), dtype=mask.dtype)]) + mask_lower_diag_blocks = jnp.stack( + [jnp.zeros((block_size, block_size), dtype=mask.dtype)] + + [jnp.array(mask[(i + 1) * block_size : (i + 2) * block_size, + i * block_size : (i + 1) * block_size, + ].toarray()) + for i in range(mask.shape[0] // block_size - 1)]) + mask = jnp.stack( + [mask_daig_blocks, mask_upper_diag_blocks, mask_lower_diag_blocks] + ) + mask = jnp.expand_dims(mask, (0, 3)) + return mask + + +def _pad_mask(mask, num_padding_nodes: Tuple[int, int]) -> jnp.ndarray: + q_padding, kv_padding = num_padding_nodes + mask_padding_rows = sp.sparse.csr_matrix( + (q_padding, mask.shape[1]), dtype=np.bool_) + mask = sp.sparse.vstack([mask, mask_padding_rows]) + mask_padding_cols = sp.sparse.csr_matrix( + (mask.shape[0], kv_padding), dtype=np.bool_) + mask = sp.sparse.hstack([mask, mask_padding_cols]) + return mask + + +class WeatherMeshMask(splash_attention.splash_attention_mask.Mask): + """Lazy local mask, prevent attention to embeddings outside window. + + Attributes: + mask: + """ + + _shape: Tuple[int, int] + mask: sp.sparse.spmatrix + + def __init__( + self, + mask: Any + ): + self._shape = mask.shape + self.mask = mask + + @property + def shape(self) -> Tuple[int, int]: + return self._shape + + def __getitem__(self, idx) -> np.ndarray: + if len(idx) != 2: + raise NotImplementedError(f'Unsupported slice: {idx}') + q_slice, kv_slice = idx + if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): + raise NotImplementedError(f'Unsupported slice: {idx}') + + return self.mask[q_slice, kv_slice].toarray() + + +class Block(hk.Module): + """Transformer block (mha and ffw).""" + + def __init__(self, cfg, mask, num_nodes, num_padding_nodes, name=None): + super().__init__(name=name) + self._cfg = cfg + self.mask = mask + self.num_nodes = num_nodes + self.num_padding_nodes = num_padding_nodes + + def __call__(self, x, global_norm_conditioning=jax.Array): + # x shape is (batch, num_nodes, feature_dim) + def attn(x): + if self._cfg.attention_type == 'triblockdiag_mha': + # We pad -> reshape -> compute attn -> reshape -> select at each block + # so as to avoid complications involved in making the norm layers and + # ffw blocks account for the padding. However, this might be decreasing + # efficiency. + + # Add padding so that number of nodes is divisible into blocks + x = jnp.pad(x, ((0, 0), (0, self.num_padding_nodes), (0, 0))) + x = x.reshape(x.shape[0], + x.shape[1]//self._cfg.mask_block_size, + self._cfg.mask_block_size, + x.shape[-1]) + x = triblockdiag_mha(x, x, mask=self.mask, cfg=self._cfg) + x = x.reshape(x.shape[0], + self.num_nodes + self.num_padding_nodes, + x.shape[-1]) + return x[:,:self.num_nodes, :] + + elif self._cfg.attention_type == 'mha': + return mha(x, x, mask=self.mask, cfg=self._cfg) + + elif self._cfg.attention_type == 'splash_mha': + # We pad -> reshape -> compute attn -> reshape -> select at each block + # so as to avoid complications involved in making the norm layers and + # ffw blocks account for the padding. However, this might be decreasing + # efficiency. + + # Add padding so that number of nodes is divisible by block sizes. + x = jnp.pad(x, ((0, 0), (0, self.num_padding_nodes[0]), (0, 0))) + x = splash_mha(x, x, mask=self.mask, cfg=self._cfg) + return x[:,:self.num_nodes, :] + + else: + raise NotImplementedError() + + def norm_conditioning_layer(x): + return mlp_builder.LinearNormConditioning( + name=self.name+'_norm_conditioning')( + x, + norm_conditioning=jnp.expand_dims(global_norm_conditioning, 1) + ) + + x = x + attn( + norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ) + ) + x = x + ffw( + norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ), + self._cfg, + ) + return x + + +class Transformer(hk.Module): + """Main transformer module that processes embeddings. + + All but the very first and very last layer of a 'classic' Transformer: + Receives already embedded inputs instead of discrete tokens. + Outputs an embedding for each 'node'/'position' rather than logits. + """ + + def __init__(self, + adj_mat: sp.sparse.csr_matrix, + attention_k_hop: int, + attention_type: Literal['splash_mha', 'triblockdiag_mha', 'mha'], + mask_type: Literal['full', 'lazy'], + num_heads=1, + name=None, + block_q: Optional[int] = None, + block_kv: Optional[int] = None, + block_kv_compute: Optional[int] = None, + block_q_dkv: Optional[int] = None, + block_kv_dkv: Optional[int] = None, + block_kv_dkv_compute: Optional[int] = None, + **kwargs): + super().__init__(name=name) + + # Construct mask and deduce block size. + mask = adj_mat ** attention_k_hop + mask_block_size = get_mask_block_size(mask) + logging.info('mask_block_size: %s.', mask_block_size) + + if attention_type == 'triblockdiag_mha': + # we will stack the nodes in blocks of 'block_size' nodes, so we need to + # pad the input such that (num_nodes + num_padding_nodes) % block_size = 0 + self.num_padding_nodes = int(np.ceil( + mask.shape[0]/mask_block_size)*mask_block_size + - mask.shape[0]) + self.mask = mask_block_diags( + mask, self.num_padding_nodes, mask_block_size) + elif attention_type == 'splash_mha': + max_q_block_size = np.maximum(block_q, block_q_dkv) + max_kv_block_size = np.maximum(block_kv, block_kv_dkv) + q_padding = int(np.ceil( + mask.shape[0]/max_q_block_size)*max_q_block_size - mask.shape[0]) + kv_padding = int(np.ceil( + mask.shape[1]/max_kv_block_size)*max_kv_block_size - mask.shape[1]) + self.num_padding_nodes = (q_padding, kv_padding) + mask = _pad_mask(mask, self.num_padding_nodes) + if mask_type == 'lazy': + splash_mask = [ + WeatherMeshMask(mask) + for _ in range(num_heads) + ] + self.mask = splash_attention.splash_attention_mask.MultiHeadMask( + splash_mask) + elif mask_type == 'full': + self.mask = mask.toarray() + elif attention_type == 'mha': + self.mask = jnp.array(mask.toarray()) + self.num_padding_nodes = 0 + else: + raise ValueError( + 'Unsupported attention type: %s' % attention_type + ) + + # Construct config for use within class. + self._cfg = _ModelConfig( + mask_block_size=mask_block_size, + attention_type=attention_type, + mask_type=mask_type, + num_heads=num_heads, + block_q=block_q, + block_kv=block_kv, + block_kv_compute=block_kv_compute, + block_q_dkv=block_q_dkv, + block_kv_dkv=block_kv_dkv, + block_kv_dkv_compute=block_kv_dkv_compute, + **kwargs) + + def __call__(self, node_features, global_norm_conditioning: jax.Array): + # node_features expected to have shape (batch, num_nodes, d) + x = node_features + for i_layer in range(self._cfg.num_layers): + x = Block(cfg=self._cfg, mask=self.mask, + num_nodes=node_features.shape[1], + num_padding_nodes=self.num_padding_nodes, + name='block_%02d' % i_layer + )(x, global_norm_conditioning=global_norm_conditioning) + + def norm_conditioning_layer(x): + return mlp_builder.LinearNormConditioning( + name=self.name+'_final_norm_conditioning')( + x, + norm_conditioning=jnp.expand_dims(global_norm_conditioning, 1) + ) + x = norm_conditioning_layer( + layernorm(x, create_scale=False, create_offset=False) + ) + + return x diff --git a/graphcast/sparse_transformer_utils.py b/graphcast/sparse_transformer_utils.py new file mode 100644 index 00000000..4af0f33f --- /dev/null +++ b/graphcast/sparse_transformer_utils.py @@ -0,0 +1,76 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for training models in low precision.""" + +import functools +from typing import Callable, Tuple, Union + +import jax +import jax.numpy as jnp + + +# Wrappers for jax.lax.reduce_precision which is non-differentiable. +@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +def reduce_precision(x, exponent_bits, mantissa_bits): + return jax.tree_util.tree_map( + lambda y: jax.lax.reduce_precision(y, exponent_bits, mantissa_bits), x) + + +def reduce_precision_fwd(x, exponent_bits, mantissa_bits): + return reduce_precision(x, exponent_bits, mantissa_bits), None + + +def reduce_precision_bwd(exponent_bits, mantissa_bits, res, dout): + del res # Unused. + return reduce_precision(dout, exponent_bits, mantissa_bits), + + +reduce_precision.defvjp(reduce_precision_fwd, reduce_precision_bwd) + + +def wrap_fn_for_upcast_downcast(inputs: Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]], + fn: Callable[[Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]], + Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]], + f32_upcast: bool = True, + guard_against_excess_precision: bool = True + ) -> Union[jnp.ndarray, + Tuple[jnp.ndarray, ...]]: + """Wraps `fn` to upcast to float32 and then downcast, for use with BF16.""" + # Do not upcast if the inputs are already in float32. + # This removes a no-op `jax.lax.reduce_precision` which is unsupported + # in jax2tf at the moment. + if isinstance(inputs, Tuple): + f32_upcast = f32_upcast and inputs[0].dtype != jnp.float32 + orig_dtype = inputs[0].dtype + else: + f32_upcast = f32_upcast and inputs.dtype != jnp.float32 + orig_dtype = inputs.dtype + + if f32_upcast: + inputs = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), inputs) + + if guard_against_excess_precision: + # This is evil magic to guard against differences in precision in the QK + # calculation between the forward pass and backwards pass. This is like + # --xla_allow_excess_precision=false but scoped here. + finfo = jnp.finfo(orig_dtype) # jnp important! + inputs = reduce_precision(inputs, finfo.nexp, finfo.nmant) + + output = fn(inputs) + if f32_upcast: + output = jax.tree_util.tree_map(lambda x: x.astype(orig_dtype), output) + return output diff --git a/graphcast/transformer.py b/graphcast/transformer.py new file mode 100644 index 00000000..6788368e --- /dev/null +++ b/graphcast/transformer.py @@ -0,0 +1,124 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Transformer model for weather predictions. + +This model wraps the a transformer model and swaps the leading two axes of the +nodes in the input graph prior to evaluating the model to make it compatible +with a [nodes, batch, ...] ordering of the inputs. +""" + +from typing import Any, Mapping, Optional + +from graphcast import typed_graph +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from scipy import sparse + + +Kwargs = Mapping[str, Any] + + +def _get_adj_matrix_for_edge_set( + graph: typed_graph.TypedGraph, + edge_set_name: str, + add_self_edges: bool, +): + """Returns the adjacency matrix for the given graph and edge set.""" + # Get nodes and edges of the graph. + edge_set_key = graph.edge_key_by_name(edge_set_name) + sender_node_set, receiver_node_set = edge_set_key.node_sets + + # Compute number of sender and receiver nodes. + sender_n_node = graph.nodes[sender_node_set].n_node[0] + receiver_n_node = graph.nodes[receiver_node_set].n_node[0] + + # Build adjacency matrix. + adj_mat = sparse.csr_matrix((sender_n_node, receiver_n_node), dtype=np.bool_) + edge_set = graph.edges[edge_set_key] + s, r = edge_set.indices + adj_mat[s, r] = True + if add_self_edges: + # Should only do this if we are certain the adjacency matrix is square. + assert sender_node_set == receiver_node_set + adj_mat[np.arange(sender_n_node), np.arange(receiver_n_node)] = True + return adj_mat + + +class MeshTransformer(hk.Module): + """A Transformer for inputs with ordering [nodes, batch, ...].""" + + def __init__(self, + transformer_ctor, + transformer_kwargs: Kwargs, + name: Optional[str] = None): + """Initialises the Transformer model. + + Args: + transformer_ctor: Constructor for transformer. + transformer_kwargs: Kwargs to pass to the transformer module. + name: Optional name for haiku module. + """ + super().__init__(name=name) + # We defer the transformer initialisation to the first call to __call__, + # where we can build the mask senders and receivers of the TypedGraph + self._batch_first_transformer = None + self._transformer_ctor = transformer_ctor + self._transformer_kwargs = transformer_kwargs + + @hk.name_like('__init__') + def _maybe_init_batch_first_transformer(self, x: typed_graph.TypedGraph): + if self._batch_first_transformer is not None: + return + self._batch_first_transformer = self._transformer_ctor( + adj_mat=_get_adj_matrix_for_edge_set( + graph=x, + edge_set_name='mesh', + add_self_edges=True, + ), + **self._transformer_kwargs, + ) + + def __call__( + self, x: typed_graph.TypedGraph, + global_norm_conditioning: jax.Array + ) -> typed_graph.TypedGraph: + """Applies the model to the input graph and returns graph of same shape.""" + + if set(x.nodes.keys()) != {'mesh_nodes'}: + raise ValueError( + f'Expected x.nodes to have key `mesh_nodes`, got {x.nodes.keys()}.' + ) + features = x.nodes['mesh_nodes'].features + if features.ndim != 3: # pytype: disable=attribute-error # jax-ndarray + raise ValueError( + 'Expected `x.nodes["mesh_nodes"].features` to be 3, got' + f' {features.ndim}.' + ) # pytype: disable=attribute-error # jax-ndarray + + # Initialise transformer and mask. + self._maybe_init_batch_first_transformer(x) + + y = jnp.transpose(features, axes=[1, 0, 2]) + y = self._batch_first_transformer(y, global_norm_conditioning) + y = jnp.transpose(y, axes=[1, 0, 2]) + x = x._replace( + nodes={ + 'mesh_nodes': x.nodes['mesh_nodes']._replace( + features=y.astype(features.dtype) + ) + } + ) + return x diff --git a/graphcast_demo.ipynb b/graphcast_demo.ipynb index 65ce6a1a..ece15fec 100644 --- a/graphcast_demo.ipynb +++ b/graphcast_demo.ipynb @@ -114,7 +114,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "4wagX1TL_f15" }, "outputs": [], @@ -122,14 +121,14 @@ "# @title Authenticate with Google Cloud Storage\n", "\n", "gcs_client = storage.Client.create_anonymous_client()\n", - "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")" + "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")\n", + "dir_prefix = \"graphcast/\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "5JUymx84dI2m" }, "outputs": [], @@ -258,7 +257,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "KGaJ6V9MdI2n" }, "outputs": [], @@ -266,8 +264,8 @@ "# @title Choose the model\n", "\n", "params_file_options = [\n", - " name for blob in gcs_bucket.list_blobs(prefix=\"params/\")\n", - " if (name := blob.name.removeprefix(\"params/\"))] # Drop empty string.\n", + " name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+\"params/\")\n", + " if (name := blob.name.removeprefix(dir_prefix+\"params/\"))] # Drop empty string.\n", "\n", "random_mesh_size = widgets.IntSlider(\n", " value=4, min=4, max=6, description=\"Mesh size:\")\n", @@ -305,7 +303,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "lYQgrPgPdI2n" }, "outputs": [], @@ -333,7 +330,7 @@ " )\n", "else:\n", " assert source == \"Checkpoint\"\n", - " with gcs_bucket.blob(f\"params/{params_file.value}\").open(\"rb\") as f:\n", + " with gcs_bucket.blob(f\"{dir_prefix}params/{params_file.value}\").open(\"rb\") as f:\n", " ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", " params = ckpt.params\n", " state = {}\n", @@ -376,7 +373,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "-DJzie5me2-H" }, "outputs": [], @@ -384,8 +380,8 @@ "# @title Get and filter the list of available example datasets\n", "\n", "dataset_file_options = [\n", - " name for blob in gcs_bucket.list_blobs(prefix=\"dataset/\")\n", - " if (name := blob.name.removeprefix(\"dataset/\"))] # Drop empty string.\n", + " name for blob in gcs_bucket.list_blobs(prefix=dir_prefix+\"dataset/\")\n", + " if (name := blob.name.removeprefix(dir_prefix+\"dataset/\"))] # Drop empty string.\n", "\n", "def data_valid_for_model(\n", " file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):\n", @@ -420,7 +416,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "Yz-ekISoJxeZ" }, "outputs": [], @@ -431,7 +426,7 @@ " raise ValueError(\n", " \"Invalid dataset file, rerun the cell above and choose a valid dataset file.\")\n", "\n", - "with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(f\"{dir_prefix}dataset/{dataset_file.value}\").open(\"rb\") as f:\n", " example_batch = xarray.load_dataset(f).compute()\n", "\n", "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", @@ -552,18 +547,17 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "Q--ZRhpTdI2o" }, "outputs": [], "source": [ "# @title Load normalization data\n", "\n", - "with gcs_bucket.blob(\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", - "with gcs_bucket.blob(\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", " mean_by_level = xarray.load_dataset(f).compute()\n", - "with gcs_bucket.blob(\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", + "with gcs_bucket.blob(dir_prefix+\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", " stddev_by_level = xarray.load_dataset(f).compute()" ] }, diff --git a/setup.py b/setup.py index c7ce01c0..14976ce7 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="graphcast", - version="0.1.1", + version="0.2.0.dev", description=description, long_description=description, author="DeepMind", @@ -34,6 +34,7 @@ "chex", "colabtools", "dask", + "dinosaur-dycore", "dm-haiku", "dm-tree", "jax", @@ -46,6 +47,7 @@ "trimesh", "typing_extensions", "xarray", + "xarray_tensorstore" ], classifiers=[ "Development Status :: 3 - Alpha",

YvqU#$3hIEna8mIHDeF>#L_HBW*k}(66iKcDX@A9a8C{r;uIJPN9sMN>WVz! zps5+BMf?+EkV69pQe#2vAbT>SDkkmd&PhBg=jjA%PRrDwQI&%-qbM8J4^#JiQ<0f$ za%Rf=Po>Yw_-AgIaFK4ES!QRYB%-nfUrHALUY2{^6=6S0{>2bAoSh%WcAMpSsks$t zDA0hi(A#%0M!^#v;Xz#9dj){Rk*4&dA4h9Z8j3w9lQNbZWxE+K@K1- z0A{^(UGsa*J>!@g!U+j8@}0a8PKoIBqpKSz7OYXcTG#4WNk+UVpAYzlE}k{*>+C_CuT}G6h zJ+qvF?ZQ`9G-JvFFJC7?LYE8YGPl4eR$VAvO_~E?nw$ysRT{1uos`Ci{AaW?lHdU^6b#G;hX`Iv{9ox{ON)44Qa@!b-<@2K=K|JK~GDZ3l{ax*N zbAp0@dCGRGK>l(FAp&dDQi!j&;#(~~<^SfN{(U?rCBuDL1rt%j?zyrj?abE`Vt+=yz3S+@Zpe z{nQZ!jh9h3`p?F#{qmsD0w5~q#8m;~aY|jWm=N0|Ux;)sJTb^1Wqo?CWFm=Zx zf#mDK120lb8tW1uU5Z}BF*km`La-Z^#ie$v!D)`=`@JABCP6Ig$Rp^y;T56K91B2f zZa>ki(9s9Xp|oa>MxCzimn2Zf5b3rAzEc!N4CO`JN$oATpdok-$e@2#GBp)HOG~l@ z$>k>#nac)-={&iR-cn&vM+ODS+!6-{2G~hh>&oUP68pR!c~(SM7iU~`%xZ}yN?lH5 zAaj1r-WI{&q?D(b#Q;rkp?wXcy`<$%cZuRGq$|#8+$`)L%b=VpP_8zLa)!rIz=7mx zmTI~zy&!KbdNTiH?vu~Nv9mVbT>4=!aE zcrq;Be$+1nTRV;f#fyhqgZt^K1m0gaew<}ANP;#%(ZGM}mlgYN{8}AKtKxzR)u*Ak ztBw9EUbP3F{RY)4&=QzlF@BYL0#|YRz+&O%EtWZaE7TofZ+&Pt9thk5ASEkv|Yw@+$tird2z6MXH-4LCF`i z(0DMK1WxTAt(%xmG}(LYa=(-j2m#_cy+wz@Rn}ce;^kL0;OT!U=>NkNMPx$b{2+>lND2=JUf5Ed!C!%j`jEyeG`%Pp!$7!#>X!&jn>2M2`9$cftzQ}fUBGbxL$9$7w6>b0k0X`Hp9!F|9|AMArN^jShy=3R^K zt2@xv#8$w&IF@`bBC26_U=#z%K_6 zD_VUqZRHf-5a)tQuRXw)8zamy;1*|{k0WJD15KmX+Ldvediwy4NOB-|xv~z~(&-GR zCw?^;%vzpF9mWSX`O7ExGd!kFzMzu1EKeR#^(rf)rgTocp7?E(E9x5>C}D>*?}u7G za2OUvrHBrRVdpVpM*~YkLS1iINn<@AB|8xr4quNKIzqES6GhFL#3w{Bbs!XmY}94L z8~*7Vv>Bnccvpmqi+%2)A-xX$J2r-j=FcCuD4?B`+5U2uQhSvvtovfhxzPu^MVS!9 zUNE?_g;F_uCIE0KdgAeO63UuNg)75Q@ostwbev6B;Z1Om%HiCK5F1*VHTlip^zF$?Xy zH@;#t<|V;m0j!SXrDQLj(~gRViy}Q1wLjTE$jjUr?cSn$YfkDtN)PM(V)yUAsNjD; zd&fv=fedLV`CVp5Z5?MEW}T1Wn6k1^e(Rxs4)+0-nlMagrX|~mI&Ne9>1_TNv&j@udL`&FR4-8|lcfW9pH_FQ^!!+CO zsCWh>=+7z#S(m3P*TW@GUx`zppa>ypYP)eLol&(#TH_zO2H2CDOpj1b(@N88Ptcr8 z-WA?E=-hi^Dkmujl)LC$idlwO?dE0$Upv;-)D*B3U7!7v5BinP^HOzx-(sggehiNV zKgzrClH5hkSPKR9?(2b^uNypBx4R_0Q%+<;!BjY6Z9t#IJQJurEn^pTjXOkIm>d_Y zwl02VsNnVue~Uj6`|nP54R=!nnh^EFTfzHQYF77*#SQWdxP$(M7Z~J@>HAOKXWoAH za;#*jK^EZ^<8?gkpAlOixfMv5k`IVoSjz><580lU(Km#{mx1acz>suqiDOShyL}V< zjAzE3%e(s2fY?{%#jry1OtA~Q-ya8snqU?phi8pn*%*Vr92s@}UB~h{F=l&0_brrvo+RE|H4divg=E5hzrp z9VGvh0{(71uxw~8jAp+6jp%?sX)kV?l;)1Uh-alY=z$E@EGK$S0Zc}!P;pVQNxVnU zTkH~=ZWjfQ*LOTSha^x#-5QSZq`jG&C1N_+unQZEUNC`5M=KCm{Hc1M*un@f2STY; z@BpeOjMu`{-7fso1g5B{n%ED(Vhzfgecvomqbpa7(_6Dzfmq-))k(eHss+;3V%LM#njK1R;ysew%{_k1Rx>T22q_ zYme&~QSHdi(WgVVreY`AQl;7S(?ZFF>w|fP41#u_r(u$$hh(3@o;0#j`SO&#Y@r`; zHH=-c7~34@X#tYl-y1D%VV?)Dew& zY;YfQ-&%%P8;V@PNlNm@@FY@ZQPNz}Y|AW@QxvgpcqK90ePdIqGSB+!o%P&)W*}jV zbAF!*Hrh*`-L%?#xw>vC{)-QAs))KH8J7j)N`qCqbY5$h7QY7;Ysc#&u8A>hft+I! z2JL~41TUD46JV-K2x`yqaR0|&6m0O`$GmSCFfb4hl++yhAsEj;FSfoT9twhkfN9L> zjHgrb#I45yUMX!>TWp7y+6TY8pb>bn_>j`(8KVj*J!&3xJ4i2cJfI@b-*ozr(n!3aYr4H`tXZQJ!kH3^ekRSyvBb1$e1x7{la}bdSzIx+mtULyK zk6Y2kQV8_sA1{?IA~<@>isDdy`NT#?G3*)_i&fSj&a3<=XU@~eT!igfb}!~Ot{oVj$hNU)<1<$#LYfiLKo^ihomc0D*RxdtA zEXfVS!E?{?0hOS7PFcGG+l$4JM(aoLENqGetfy&LrBhWZE!0z3Nl(J z)l_lBL!uRsg+rbd_!&uY|7*SPJP$#hN&*rN8!b_~KAgPmJ%R}Jle*7DEiE^JyDt%Z z1Sr)Z5=ac%2ayne`&32O8TM^-vKXw?`*b)4xbrc85xD~NEd}ly9680@ZM1d%d-3@F z*ItE|lABMcxfgDhpS`2sdDxr>9(VxI=*0PGc{rzdj0RB5U?MZJ<*a4t*76Ls^;Ez! z5$?iV?(J9SjeHPYLfjCI)b4xf@u-5=8j*3cE{b4^6ORiIqc?Q5#cG>2lWLwdaP;$0 z*{-fDFoi}tgW;9CW3)DZM`8r|qQ)tB7T2}yGJg{H(Hf*nV~TvN@!2W#v(Gww1u1kW zus4~0ZFIv+R?!p&e35&RKhFpcH;!jg_h1&Gm7!c9K+yzLF;213V^CFJ^W&emr@$2Y zy&f6!&RwYiG#ChgNThEA6{qhkydDpM1l1R_L>dlT9Rb2Wpf8HEh91pUWL$X~eLsAR zBY_V4&X4p{dx%TjXzU+Oz~WC-);Xw;7O>gbTp8hBNqgy}y&5sAB>ULa3P^oM1f`M^+zcNzqN4;Y+&zuK;8A)Y)$Yo3$JyNc>KFD=sN<(TOuT{L5W@f{P&z+; z!|5GOK}@ z)y7@kyc4O$UW>OMs>yVUMZX|11yRkfvd7i@(CNV#;~Rn=AV+LQ=Kga_;q0_F;Apg1 zH{Xp4j1o5{*QT6eRXr9X}`%6r*R0JG2F2 zY}0S|f#sL_oVudFsz!Scld!v#zai~f zeDRcqtunOJg-)B6#q5fHGD;9y{|am`#R6vQa>mm5tJXb`u^uAk?n~l8B_jjPmoh(K z1Ca_4L_Kqz?X<3A9M(UXT}WZ7cpySk)ozFw z=nQ}G`Wu;2B$n~z;BoRx&eUQ4grz|pFLd>VyZ|)UINmQDEmHyFfc!gm<^+*E?;r7W#t+2`r7F`&Rt7*A zOG#1El?@F%yYtc{|{a|$}$^AnDoEv4Know>X1%E zPjpp`Re0$oTjeK%`XPW5O=urW1!#Am8Z*cLgF+f}A&>xE5(q5$iD7*#gcn&V+%(uM z|F?=1xI`2q%-NKSmj9Xiv&F;-&J~*qRE`s6R@4LsGJ->I8zhNiTLS)uP$g}lgIM-YjiSix=Jml!t!DWGbh!$21M&D2EK=Gz$jp=9Ge<1{ zR87bQ-bG>Zb%6UUF5zBQ&TxSC8&r)MMsIr+@Y#c5c?K7yUbD}09a&54!|CGRFLLJ0 z8)j(I?VuO=3DduPq7hYvmNc`sKYBL3V4gO)xg44tYH58fX2zDvz~{t+i!8w+ri_Pd z6civ!DPy}hXUR6Q(?x2RZ-216j&@~z<1Qou{~Q^maa894NEjF@@1w_W?>6r$P~}LP zrX2}hsjthgUlNJXA<^y34?kNFg-qy2#>`|KGym#y8m=wuXsTb1&trCfe;u(rHHK%`Qm+d#w}TPtak~2!31H-5-nd&N8)g_4DV8 z9>J8PRvF5gFz!*o?DZk3&+il7a4%`g>uTc{Q#lK4k%AKNh3g7aSZF5G#rJlO_W-$x zdV2;zLxar&rzv+|8S7rx(?Clh1rzRCiE!>*3A zJ4q=N!}N}ok(MFSe--lDZJz}m8dsnnYb&+a#07MX{(d$ckRwK#vAp?9FL9ATz|TP3@FN?0{EAWL zSE(vzSU1xKhG6~A^%8vmq5u@(e?9P~dBXdciZhzRqQmr=;|oL~a?}}UC)d_V-ZpK@ z2ZrW*d2SL*Y2siSHZtp!wxp15k?QvL`6FD9ty`2X9-mJAo*s@j8OF@xXOZ12%cETx z^z6L9B>3E&H(5q3WM0ryDI_ivTZ5c`oKjH1eElidfq&0c@=j`hZJuj&+LL&y_?^`p zTp7)~|HT9}!xbBl_1$KVgXnI)nZW7Sug4*ek;~_f5R3czz8ZMQ1EQsLk{K+6=oY-g)PRIN_fjkow(- zY^J23cTHB+aovH94SxM+WLfX5WL^~o1VrB;gBmp#*|bUpb&QBeJubmtr&3%F`2sr) zI*(D2-^04cWakY14yE0*NftZrHcymrS#E$+xfXv?qR}+)LkTj{z33fC4ZZ<`_?Eds zCp9q;ZYno_zpY@h#wnWa1vJQKjEaFn;+YikonnCjUpS-c()Q#C{(_Lpvod+^e!dmD zRa~lwMvwwO+)n~x3O$$0Vl)tq;+7K$uvg&_r_3UFqPCt8v?aOKh%7T!um+tMF3w@6 z2q%`8$$lI13xTt|cM-vqwBf#SO#Ll_xjo27}d*?azIN&g0S8 z?bWLlpKj7SB49=doZ_WUDj<4hgzo)=ktLK0@y?f6KLRo`Mtyc$L?t=Yu*y`Yb;)Y7 z;yRD28`#NG?qDomEE*Wv-C!O*h;Lk`smuyD`|cyDt#a8b>E}X%23qE8qFeX126cf# zU{RlsqqTo6@?{k@5?}NV8Iy!$a+X(qmOE2vDU+jkqKIKp{w`YA>%qiK)E~tzML`DX z0|y)WqSk!SwA0i3dmP=b%n)OUHiGf<)Ayj#{D#`o`kLWVTnGQ)tr_lz3w!%VvB2#^ zR6h;6J#b}fYtD(p>bf@@5}2!D2AeXhtIy;-q0aUocIMBPCs*?@_>UEX$p0}fz{Wh7 z?sz`<2^X&m4GwSi@He)H<)HWPz9v4Rqw@%}IkT=up+H98y!!rywEkWJc7M3h0x^h| z!|FhJ`x?a>xg-RWLu7C6XK9ouFXWR6B~qbwoGm_oTuNFS#N% zD1bG7F|Im(B^-}R777)WJeZay2y0(r%YnV{#ZVlNl$-H^r2Dih>!YXb{23&kzqra* z(kSF#5(j}>7!>)i8=_Ps#rA$^d-Q!%cESUxiY=q1PRf^L<(D-!9}(B6BvZ6CU6F=_ zqKHk5T`e~mt^c-X)==R%613VPf%HitNs|1Ue&H-SzG289yr;TV;^SqSo4`4qx&4wL zFB%$AS9^IlovNuBj?wFI#&xXXWZh$($y&S1J+3>s)0|G-$?h}UAZc@0<3{xkU_t}u z;Vm=>zc zKWMgKF+uo(tNd}AHhySS!5zT+Vy=xE=*O`*92POU#dW)QjFA{b!I__rmu7Tixkfsc zIyOM4r=Fz$6gD!uVV=(XyPvpA{czE_;apY|KiqHV>Bh6@_=;f$>J;Qt+ZA>aDHvzg ztF%lWTA@Uz@kB1K-p<;2UCPC5%<6)rCwmfd$2lno>H@Nft-zYsY>$BgEJQ6VABK3%SHAWU_ zU};H=3S784g92D2Q<&Pu=_aAFg;VPz_&sB!jwUF|5!rz$*&&U9p{Uif8rOQW|NL-% z%I0A6;svG>2xI<0mjNuIF;s;-?TIQuI+T(fdCSx;^6@{vI(!w2wTojsv)Xj#AG;NyCs}!8#!OVz%Z=gM#!!Qmmgi>M|W3g zlv+>;8wL(2LLt}Q8yE{NAJNsDv2Jp(e!!J3XOD!C`8+j{y*CA?!#(2%ep$_4j(hI>1s%>2HGl<;0`c&ISt`q!>}1>>XJc zJB4hYV6QQC`eOOKKr+GX<*_@w)9h!2b;x_dsD|v%qoWJ7CrMHwmcf0c3KBXwCLXf#q+Z3Nu zzG1;okA0j~hM@Q|&~~_6GV7Y5=!LAcEbW!u5&W^QAON{Xe;PYayS%U;iNxlpAY#76 zR)bJAadY*8BSY5f6TgU0a*)pcT^5(=Sko3hv|mZAZ12;q`#8dcNbbYmiw#Uxl!OxG zwhHN}3w-JPX5WNVeFtNZD~Js_gr-3LS$)9>7mpYFbLL9B%|Y`ko6lEkm>~*Y87^s9 z4OEyJ`fAllziBy4%oox-jI)#W^cJLM^6s@3todNfSx8F9B&PCT=*|OYWIGg+F+(AC zITjG{9w>H{7^kgHuw(s=eTaw3)_%%9rDv2G*gS!uU7;0Hm3vKHD_GNzPFS`M+Xae= zh|T&L_#BbTr1eKr_w@k$I5%RGtqq@-be8~_;n4vVb9%ic%)xwa&9!HTvJ$YbKQ0Gz z;nr_%Q*mPB{*O?l^6R4&{DS%8NJWz0u7&sY^3)p`cvJjo=s40KR6b|6O8+j}aD9KP zhqX66>1q}>M?XRxRJJOJd5l7ZSz`ji2b$mnTrhvHKZQf$)|^L$MD3;;=H(%)!<3gg6cd^j^8~p|#W*0;^eD2rdP-U#~^Ou0%{;=^;Q{4et>NV1G)Dl8B)rK;=`ic7k zFedBb-9iDw!Mw_LUth6|iZ2Js=o!|zvXW8guZMex=S^AWpN}b|XViP1fYvN%gGB=h zm&F{FFN%YYgO`t2^|_1sKL|D{Hexm99T{N&qA?Q^89|nx0|LVFjB~R51$iiemdwiV zVXh~8DP?_F{@vtsgzsyA1bh0GRRZ$%x-Vpl$sw<##f^28lLPg;+CB9o1z#qxa9(Cc zy~7oKC3{AFk)=4!yiJJQ&6?71#nudTLzCXi`J*%;a4y%9@a=QFc4zcQ10NU>j=f0= zY-~;l$1HVs`C72rq+FNlUH@hbrKr!HyNUrC8HJK*DQSm$^{T&2z}#rM2L zoW)FiEPZEOYqQHm&a?UVHKfgfQJRpE>%w#uoCFiIE`L6>%?*cq@(>M|(k#fW3Xf*oU3uSB z;AwdhF1s#YHjc9$G&W9CQZg`A2PtRb?JTU#k1#IJ5a(-ZFdHD)$?Rc)$>^tt zrPy8ufX+Im#yH`1dKxQTWV?V;d*!^yc@avx1azn06ci3gb^-ScFX|uQ27DNYWxYZi zk8TSpQ75{LtxMkV?r+vrR4=QU_uCF^J#Ji%*KRE*8)^ybXR55^*^Br)^L-77!p-0$ z`)XjtGYc)nY_=j#(iN}Xtc&fHCdc~d>q*)JO5f1ClS|lb?4u3=YX{h2DOe^JU7a<8CIy~Y z)P1-cxKVT(!|dnmG&J04!3%vh*Yg*1cXZg*R~O1}wCmW~K!tnMPQJv^u;Yck2O2bW zFX5R#2ZtMS<(u-hTf_m%l)q6Fbkogr9|BeP)aq<2V4f|V2k^>;I=RY9=XY6QCyts_5b7SZ8S-{af=`GxMeF)Y;)sf*04z7u zK7{f}i(Ztk_AT5}(gtqIWRfNhP*^Kn#U5I~G5%+q>r2D;$}Y1Eg;yae%{}g=dxr8^ z`sMh`!by3LQIlvm{Jq{Sdt|48TAJ{J)c)1d4OE!!y`99ZMboJ!5HY*|$lC~j^rrf= z2n~uZuiijNWZ*yZTmWnbAji0?k`8#`RKExJd(J0Z1t(bYq--A%hwJjYzb@46X@|r2 z!u?o;yA*;7@Tg%v#6SG&edQ}r`Q1r)?S4oU`-jCQ4oJFLKnhAY7 z7e)YTMsh~UUr1kRc`ZE}(4fF73^6;&(X9GJsHw_jvNu>bB-X@CW zh|O9t(Y7n~l;SyRRJcbFjLO#_BZd*K&}?Qc7IOg=g|I43jQ2TDobot&xhf%7eHT&@ z(kOkq0^urXP1Yb`FP~ z0N17NObcwai4%^)09dxhd5e}IRCaIg;7}tqV(KH7@yt)UynA~SCmP&rH#E0roba@% zefQYa1@s#tK*o&n4ISc^UwJ@BU!e&u+CIKP9v4S*+=M8*x7?MHHXZ;9spu(93fUOP za+=+*!cF^4e|q`D!SzA%`SM6+;*sW=L4!4^TQZ-c9eliDfRSUbK9=~$%$FqP_sVoL zoVh>9!1DC!D0t^5$UDuO+&WkJ4?F%)a?3EKWI0$Tu)U>C)zb znP7@M>mat5Pd9C!_ba=z;?iF3^^bEHy3)$L<+*WEM~r%wIq6wUvO`#=bx>9b_2_R) zF7Ly4Vzi&9c&7LFygjW-2lk>DhXEB@btj=n6EWf$W$Lp>}s^oU=%T3nrSDg z=!nmw>18Rz^{)0b7*{aSROqYPb$CI*UQrfSp<5ga|fD3up`4I(3$BR(7N;G zk8BUBR@MoVxmEz8mt0kB7VFK>|E>jiq^b9-x21r5b>rhZeqh8em(}e6_Qi(2g&Ib{ zo)<0pLQ4-l?^vr+@6^tVcV)aY0asH5S>*2zuEKy3nw~RGizgyyN28WN72!(>8sP6uWp+6}n!H7{Pnz8|1DE1R}a#=H|yU&fQh<{$C-`%!w2FdcS_14N-Bhw0~>^`FYWX(T{TiB~wi{$Ex#I=s72o+@(FCW$r%df%-pJ7^}g+ZYbIHDE3Pgjs6Uxk3CxD;!o){+>0xqdkg zg&PuNECYohbtl#IPG?4@maZ<5278;uS@c}IGv{HjKDTR@H#5gBPu`oOL(o&2L5$+- zzM+ibP=NRgO?yHED#R2`v4u*wQ<`!hhWBS{JuV(z^WMI1F64+h$7I^z_@`ddSLjGT94=C4PfCB@zfdaf>5mV{)yqwHy$)cHfU-efCyM!3> zXaoW*o-s$Q#VuvJ{lI;@pNApf)(`1*=`BXjmE36$#Czt&4nj!sgjslQNk_fPY21;P z(l*9>)8b&R)1Xl`wKATCaj0>MFB>N;dw~v>4vTVstAswn?-j6Zge!0`IK~q=z9(c9 zI(R(FtH01nUP)m?qncVySsiq(KR@hj5`^Mh$_cgNu^%#Uy7Jy}ORWnNO+uB&m%Hk@ zpy>%$%c@Xhu(#IUcQZ0Pdh9LZqH;s6&Dv2wc0S@O-|!5ty>>VI7Xf?8)f1AT6+=7m5mu6?+A)I@<+b<(^ zsN5+NVzP(iIKme2o07N9H5~P+?oVe1T25<2@*7`sG=e36Rlb$Zw^=Y& zg6nfNIIBst_j<^Zk3-QDM1sboz5|gAR1Wu-JDH+%%o63ju0$Hd>oeuG*5(D-g|ufJ;G?| zP*4xJF3^QJx--7i=AQx2pWa`CQUyE*5L-XTnYUMhSc9-x2Qt*^Q1l7xjyG zpzVfpKVdR=P-CKYowOr)#pjc~a#Z)=TbRzZ!_PxozbIs8m<=2m1l|d}PG_Vd%Xw8p zbutSDi)DZToY@kRtb67jix-r+0@l{wd+UWiI|)2MH2#3G0uC#KeP+h&iaE+7TXf;q ze-?z0XaN>h);}Pi{UX0e%)5d&sYv3)P{X(!wOsXCOBzs+2jZ$MpLbx&2=Q0l5%+-FYN&3O8$f22^4pidcNm;rEo|9oI1&P zKQFLTkjgG$Q>9CevdrlWC6KD%fDr{?ulQ+2L(#u*$zgJJyuV&k&pV05GoTQ(X$|y< z43K6oi%xxZ-D_$ktBRB@S( z&&Fou;0N(Z#3m**{^{AD<3m_%!ThuQibNr^Fk=-I&@1G&!;@;Z`jq=J+|AMJ7e(FX zgEV3h(Qk75P)MMi8=TJ%Se@Ym{r%OClndf|7>iCrUGoXTvyw<`+fME%Bq^I2J&NzB z=Btgn6Wr+_E6YhMV?i%sy-}xD>fU}5nJjw|j*k-}%`+0e$D>kr>1Jm_)$GVP$fvh$ zat9SA?mXCR;Jd}5XTrZFYvrth!G)SXabshB{zRG73T9Ia{T^pB>>vDc2n#DoLzRif zp6Gw(L65#qH-z#noX{dBM;Q!jLyh7Ye_@W#YQ$Pk*codz$jY6BaWDO<8(F|z=A>kx z$|-0ZQILv@M)N1L1QaIH8k%Yc5%F=A_|}=14q)b~rRR}y*Oi7`&e=yZ(O{4LEr%_2 z9jYSrKw_aVQAf=MiA2XC4C|mMb!WIfVE{jor(dU!s0V(xyW}AwQ<@aV9o0Kmq=LXs zkWWbxs-EE@`}(_tR}tGNiivx4!l|G@YtnkoYg zGfBF7U8HC0;9x>!Zrf%H*!}VFc1K{F;Ed05rj9J^SGini3F5p|$LEs8BA;|N@18gS zwjJ;IQ$LCw$WY4hlerA#Yd~nV(y{Kd>8YdlX8IWbb>=$d4^$Y6+0i=JD^~ZW_;1M1 zAPlQF7_S$CM|Z*h(EQsvGbI3T7%FmY>vh9zoIE~gc5X@`iTzsi%*LJh0D99ZAgHSu z*5Kr@@2|kb#wMi-ft(VN{;Kfug7@GmIWnl?g)S;l|rY4a@DotL=592%~Z;Y$U zNrr_1ca){j4X&XQDofyYfK zLw?3uRd~Kiieh0R>?qdVdma*Jv?Qe@4AFUXcFowfxPReBdoofvWo8sP(-~Z4YnShc z)>r<7mkIFh?ZI9Y>p191Ye}7QF8H?ymhU3Kmd`8YtcFBFF=PVX=jElfTdFR43nRcw zks7wr-vZ66os;elBszl$2Z@?Hu~6h`R_f1ei~~ZGvVV>(iG%2yM~K z4RWQeG&FU}PR&zSQC5>`kFlOnVKJcosdVo_#dpQC!&d)u$5#R>MN(J$K!b>e>IYXY z$REn^mAe?3i#0smcm@UZYb{PI1dP;!Vuj@P!jD*Agg<>BxOTO3L3?k0#^!e6jeQpX z&lJ%2W-||x5hS)Sew4m<&b2rR?~8&WHMiuXpn$;FY_W?}+j_F2klsI^2s_c5nZBTq z$<4XZ?z-FA?y3EIqxw7?I!7n}V%DhVj0ywX<1#yW$y!uJe!uqCDc&-DA4?sM=?hBw zuC)T8r<|QO+pRx+xcefl_o!vI(L00{qr5ln&g*Ek_d_t}3l@A~ndHs){l5M_Fdij} ziqCE~Cu7$SdT7~j#rE3@yhI+h+2_Ze?sFIfd&@ou-kHLYf8AYJc~3kGWZR{5nuy^h zq>GF4SGa`chqEbouTLK96YB2AH|4TAVe$eMLm+n$F8TbM&&uqWP2rGV{*z4(SQ0TX zHeqJQZX85Vc;W|!=Al>xmYHNiZ(BH{A`wr)!dOLf6i9?2&KVR;+@vsDyCcgJ3QSIC z=}`O|9`h#1QPu!2L%bdJoU!mHB)AC-mPF5lsOX{CJrY{SL=XE&t_Zq<;{N^JhX=6_ zKY^cH;;)Sgb)Y|evW^=Tf4hu69u5(X(ZeFFwk*a>^mZ+gEY{e z?69qCY=@PZ`K9lMp9Xygl>G>Mkm-yM$OfzeqDPvJ3rc|guh$=`=l;~95uo|*2ix{F zpo#BmpC-nCMkIT|3+Dr+G>H*YDo@djtoYWY3!r)j2&$&qK41V^eV;#nX2*pmiXEn_ zZr8*7Nrd!IXH`Jh7fjo{3J8pJ_^h11p{mpLzYWyU#Ep=YDu&sza~u10Q)EgztlACP z?WdYt{iB};G^&r0B)jK?(oZ7;niUsHB-_u+iEP=oPCI4-CyO!1Y{!mEaA0gMZ+3BYb@V~DcNWkvxjkBPuK(zl{YmXmoo??Zpc@?U0WYA4 zjpxwguw_9xNBLn2D zN6+)-`~7_Gab4$iUgv$EL((gAeH5l#&)~VsN|TAFrjhx3hTK!l=Kfa$*JR2Dns3SI zLRtQT9!7=naC3$?S;)-HjQ++WxW!dsym*}iZqH7B^{RJeag3Y6f8;uc=D~Z$Ot{|J zQ6DQ6_p7cL-(gR@692fjy=z4@Oy7wacgQAqu=$tIH3oT~j)Td;)h-y@{pCcb!$;&{t-ybz4DKK4zthFp_UbK2Q-n|>X~rAwFg zg`A?gwX_BI8UNg3IJY*;!1W*&44GCJ>B8!rzZ0*RFx_AFtU7VH?8IfJjUGiY#2@3y z69y01N87_w7`J3Q=kH3Hi@EZQ{m(^MJJy=#!gU2U>zGX9Gq`v;=Dnk2Epc^#y~GXN zvkM=sEn@|4jVa8Ib0 zTa~uX-lg9H_8T`;3t#&D*!@qvx~8JyhS@GpIgB0rY_I_o_R-D1id85tdzzbzx`OHl7K=k05CyBua9KLjeYgR=IeUL(P$v1@XSFfxZdrkW!Yj|r zGPS=9inLjMlVuO9p3eB2Ys~c1UV>unZ3-CVVVYnmVr%=*NMO^g>i+N7B%|GAZIVVb zlwlW!=#W2BWFq;#_Xhky*4{3za8i?EgOvq7Z_tsYc(h2S<3zcm-9K(D-t`0R{nxP( ztVFL;(jQUHSbS zO2^cPRQC_tCAl0%HV8Vw?l?A>XDydc<6Y`O<19NHm0tQ_R@2UEKF5iQzTNa0SMW(M zEC?9L;38#sP{MZkS)zOQgittjmD!iVE~o4D3^4`0&sTA&@+X*yY=+I;`cF*);pT z@LlHI@UN*Gi)A14aDH#WA4kcNC@lmswWN`*T)z>X2lG;%hZ*D&a!M~(jLmj*oRI3e zN#lLjNFb+LR!p^Nl((3z8a|5iZqadZd!$}o)X>PN9!ogKqiATzP7!?4ENt5G-)n!# zq0MF;2Yxvz#^c$|(cH^l>xnHROpID_Q+?O+sPl>C*NKUrR=nNq6pUcR4UbGDVI|3^ z#FP|&+*sSXBZWQmlD3%;Gw$Goef2plt%S#%Y(c~bsCQzHe|1fC{eYlq2z0jE2VEEWrnbQ5>dT>|Gs$lt#xo^O#1MFwMm5Bb3Og18Dr=s+x3*H3{MWf z%+%CWBtvW@q?=zbb?v;g=r?dqQ}g5B!6wAnkAse1T3RaL0Mhp1hH#0IZOuFDUFAg`@zQ`y^puJ%_|-TLHGBg{>t#y%Y53Sn*2l;F~tyO&Ik>B0{`- zdAaQ^pQufzict8Ek?Wt$gkTm%BQ-1l@7_j{z}vUih+g}CTuZCv#3&_F>v2xF;>15T z`RGN3zqhcsNETzy$9qn(`(x@eVS|gX2{+kazs!L}PDF*2_H=hRcXf=&><(1I74c4S zBL!Qjf9I(I0Ra*LPhg0UGm*RBnAfsMjsMkHINJ11u%4fuYf=nfEX~H=)S`S4OXtJ^ z`uNn;z?Bsb%*8QG+mlOz0s_Ua&Zjh7UpAdj6IBg?>FU|DJN{*p@Tj`~_u!y0+=g;G zMuW0MA7c|{YQqxvlNZ$v*w;Eb|KoxdM8(z8+)UP=tdMFGn>u-O7;_*Q!7r)PWAU1# zloXI0)BTe1aRGcXv64Qy?@rnf`TcLYHo*wZ-vxk7u(5aem$$yZ7U3RcULPHpHUvY@ z+CsG*Tq+!NzHcxxuCA^~>`vo95hfE~z1p1KavA}LxJ9IK zvJsqABSeX9g_Nq!MC9F)Sef_qu*k`9owg(|aanFr8`&-K%3A~yY-+JT&VFV>RaiaZ z!*5C{T>3+`T$ZO>UeNgePbs$)(hj(%=uy;F7lx2?B2|cQr6?}WxJJ&(pEu=j?oT?n zLm5%Z;{tw^2p{HuMJ&Nn>Jk?8-13UMwV~amzNPsoi5{;VPbIJO0FygWn879N2Nt6= zaEF(|p8P00G=RhgM!n)5Z+5k7IKOz(ONHwgEP>Mt7kG$6-AS&HnV+|enSjCTBnDX3 zoYdbC1Q51lx5QxuNZ0n`57Jr5wfzoJfw(6*iNgefhqW#JvnSx_aoNI&xPQhLFq_ zS=SH#`O^y={SyYbiU$?JA2_!ybyLc{(JI-9?fHifKdjraqUSrtJSwUW7fJtNt~AeFzjn>aa@T6qv28Ms@lDu1 z%&e?t*t*iof18NU6x03NiB}vQ9c{~ZYT4EZR0$PaDP;6~!2?F0{Zv&T|j# zCfF-JJ#v+S6kQ~nT)6Ok&BN)T7Khd4*-Vr2&9DT#Z+u2uyXsPy{->I7;SqWvM7JNX zx8c)$ACT)XpsBAPhzN&Wx`jW-NJaivh?zD6p0RYBT%`Zu!l&>LJPVlAH256fyLFf7 z`0s=udwkn>sT~_BDd$&ai|XOn;(~>M@`G4;o$LOY)nd>`JZIq@m_4?)6t_M_MZ#3g z?U!&&2DjMMyfj{i*hLT*ZS1(Ma5lqEg_X28d9hPTO+P$Wmx-YcOyNRc70nYG5BKmB zY&X1Cm+DuaKYwl(!?pfUwaMblh%_u(W-?n^TYcbihxz;`nUUD*(cd0R@JgI-az1CPC@6|BP{~X zqk8ZOSJ&6yjjadv!lC1Oiyi)HE2+q=lk|u6rvQAp!DQSwCRT9(N+l?8E8+5Y84#<2 z?Is_v8Zqu0@Z(K0a&dVYE_~U3x>*Y#ItM;)#OSi*`Kgc`V*iK5PP`I{fuiqmP7yJ& zOza#Q>2Y~fd?*bO;`%>_E=`6N?PH}0+PoskO2hwQTtQ7{`kYkdS?XKFT9X7orx)%>{ z$3QO_czCdrvUkR87?v|YTm(Tth&DeEQ&0TkkWGho z)08(q3*lBzjElPslhwA8c%&jMz?`MY>TvikA;Dj3e{cMZ3$EM5qX=mqU`JhgVm+MX zX8gX8su-0BERT*Mnj?$6dd0(~^3-TdI(Mp03edAyMAcD$0Q}5Pf5Ro9=L!P}mOv%) z(Uh#V6!-<+^bhvuws&tU6;UN9F;`Dd&`@RFg~(#w#C_lGc3!c|8X`VC;&z;6gR#9f^J`LK{pq$#ydK-u>xo@>z_po@1k8K6Tkh zDxepP>%t5@Brv%!07QZT`=+Z~{rKQn3}*kUl(b%K`tq^a_7FYTN-}asInrUX5O1x{ zPPI#y_Tw&jK+=0I(shVzHe*i?&g_*xecC^&?s-;LDC#%xIzYZao;NuK82s&u*HE50 zf4*F>g0-~XbH#yW*RGMr`mjmA1TgdgX>De+iGGW3xyjPsBssFxJ_%40i&$hHz-pC9 z@n9e_68_VD_zk!bMRWC7ZHw0sY+pa1{lP_0EvZTEwx&w?%xq1D96&p40q5W#UwNC7 z4LLh5B?V;3%ajyTxL>J+3*Adg<0meq*l?IzZ~%>8zL-$TJl3S%x)l}3Q(1S(vxzEk zuyE&1%eix$6QvA^x;AcoyWFNCx`;bE)+`1MV+wyDA-mq=z&xa!LIMXs2lgZj4s{n+ z0^XIGg~bH*JN^akQdruiD1V5z?DEMSpKr?-tA(G{(;solABkcqcU0mJyL!%jbM;$< zzwUT{e%RUG!I!}-2yXGfjb0gy;>ymOG`)o;0CY1vzC zxFQ6`ZBLY1Q(N1WmQcx=R?{oi)?6V~wjf_n6VE1Z1HlH5s7>ohplFDHdgtZE1mz}? z{KJm=+4ZO3Rx9k6&sf?66~u}u)vb$96k%l*re{w@{^_wr@nnq&QWj@5|MpsYrq#Y= z1Ebq>L8}1=`&>N~kv&fNa0}ZPsvLMFLeme4Q`q5`3IHFFCTil18#ivcr$Dw6ruDI= zp2vLu)5KhLDJMUv2WnCyA|y&-0~KTX3E>%~2J5x*cUPK#``HJ@rM`#+!mOod+7;;$ zix;sJFvIlA^HVU}^q4DU|I*%mwMCqF<122}XX}dzy$iZ;!nFX97~een)u{aMO~aMR zFbU*cj21x=09h!h$;ZII4)kz{3fFv7%!*gb5HU;x&x-#S2lGiv#e}aNM<`r-FX|nv zVlU%+bM*M7RNE%|t0jhBXAXC8c64+QOoNa`mIiUXhkK>c`SYO#cL%@*zV9d`!*{Ly zK8hsxy{be>h9c+k=MM`>StLGwd$$W4*y(c)fl_dz#j{{%IQgaT-cLvMKVDHIMG>$^ z!y+|Cg#S-t#K+TQRcOt_g3#)IK34%K_iWD?Nh6m^df*p+*{L9@**k9pQ3sZ{_*IsB^WAzAd#*JSd0QlmMg zYEy!ypQwM;Mx21{Elghr5+khR;Tj14z6>N!Vq;5!RH&9gZv74w?5x*%xsL*OiNVHC zy#Mi8aZK=>=Wv0u+qv^AtE-WZs-CNue-S>o{^TZFSE?_%wEJV%uHC}(=g*1aX8of4 zygv(e*X7=(nHV<)4Zpp$Y4hf_cm8r6S-;7MPFdm9?Z43~Lc-y)MRJq*vmx?#XdlMU z&b!QwA2lB>AV$o&&YdI84GqNrmB$O1;9my96Hind|E%9$0yfx;d({u?=N%uYi@E`t z$~P$J^4xgO-;wsPgcCu%w&E_}JDKG?q;V%J{=-`siSg^@$O9;nT7p^DO&rCMDUA8d zE?vrME7nA&0IXAmm0ogM8o`eQj=s>WTAuByhqc?ulP8J&H)3C}Q+~uHhfS^$h92|# z%gf8DnFMYTem+{}Cg?MIcoeI{Xl?uU?R(cb|b6h4kX}pgG$QFh_vJ|@mn%6=^=yDya>+~N*amQq^BfOexpKwp-aWx<-$lUCRQ_&E2p2G3TOD?|V#LY+!AUte zIRym;Ld>;i+f-O)*`sa|1_u9FS{jawgoeIVNC_w%Uk8P_GqH^pq(Fwpra5NRs`CSL zSoQutLU5*}2~@gUVniFXZQHh_*w|bB)nS>Qs~&D{Zch<0w#m+M=;t{S{7wAJ_wn#- zzFR^TSeeiG@81>%H~M#ehu80`sgyz-NL;tE0V*LY(Oobk+)=S-&z{>_hVG%57hyQt zYsBUs>g(?BFGu+<;l3cU`Q9CDEdloh#~O2C*!R*LIdUX^>J#1rbN}(<$Cs~P{|r%= zJ(|*NlG*mVA#UB_g1>ZlmfROKG7+^(3JPlxO}RNZZascXe*xx~h#YH^*!{latm=d3l)fZT(HWR{o)cfw> zz6b+uMMcHqN=lQcAA+AgWul{_V>)6?KN=Iz<@1E@_H}AN0MQYM*x%H9b~vE0Gl1^ z>gwd#5nrgNsCwW@ym#L|ujkMAQDB|c|_6`n$Ow{nip`)iij;~+0k(##s(5B`SN=m(b zeaEoh+S=ME@FWU+yjeG<%nKL(@S@u?*i|HVxNwe>JZ%Xj`%f4WmiOEHE`Dez{QK9D zn3`X>z;OHa?HV|kE{+w&V0TOvuiN$pyv^FRYsH<0naHSOBHiZu!~PDpg*@0T#Tim{ zQdRX5q7D=FMiG{+_+pb!hi>VlCO&_@KBbvw=2Qc7Ug(S3egFMqE%#{MbE-SW?wCqU z#iN**dSdg5kbw>V_1m|%4g+d4qaCO9^jJ+R9$bPi>NU|y_$`G*N7tbkK6&)$8eXk5 z`qTW^e&hsWJoB7Jluw@Agv+{t4qzLP(LR3S#92MP8m#3x6x`#pvzOsSNldKqfj#~G zirBtj!a%nAY!9a(Ehqq_?Wfxb*@0N<(B(e6eaDVa>4_w@sI)KZa||9U??)-|^Vcs;kmyAo z%P7+ zM?TzNsgE8#Y7%j8UeaaUXzjXnRaI5SOLG&wxn7#5$zUxnOeZ*xb)sDtr@h|g{P~^M z*}C)+t1U$<%~l>1h-y6}BdVwi00#E$-;Z^)Z6h^PG1Bsz_7Upcl8>>2&R}-1q~BcR zbGdnQZ&)?6pviin@BmOs;{Zh$q{p_&Eunw?3zUf|;sfT>VP)wC{pJnx^Ydr~H5Pe# z-ZF2a)%vrhLuKY$pr)2q1qWsrUuWU}MXx2E31;VM?xF2H<(x~fcD?kS*-=>_aFd4J5r#3V(B>zZPvE0c%?ZAL~0 zK7>$hu)3FuwcDnfZM_?r-OJm1YP@vY6o{4gRcknZXCX|PX12XdO}#YUT_W~1>Gf+q zLNXI;;+PnA*scRQ6RJ(&*s;OZjJ+H0#;MZK;(Obye&vr9?%lut9>&IZ?FfZV>JwS_ zh3@D=P$kKUl)%FoY%=gpr*KEl-;Xu6Z4$M)XtrwU$bV7UZGofO^u3Oz*oy7q7i}qjECP|Y2u>L)YPo3 zt=$>31|g=^N|buzI`@UYd$_nZ=6BR#d11o@w0h%{|GLT)xK5j3y}BVCuivl%(^ZUg zkE)+m_$6XPTZSoRrg?4r%a+98*vctj{vB3!5Muy@HtCq0srhs2n+ahWpHvAkIa6pE|x*eaGh{Mc* z7s)!P(13N0Oo}BcpYJrnBw@b~`TJ$FWrmuJ2#e8I%dD5n+gVsP%OfxWmf)iS&l)U5 zo~m$p%;tB+o#Q!p_%OG$G;3y?3Whm4IvQP~Ohh=1EhK|y>D*Wx$)?)8IUZFh9)*u# z3jm*m2#dz9zKjaC9Yg)?R-%HMXxW85d`OFWcw&0GA2W#JU^^RIU0;wKN!g72dSqH( z8x*iuhB&zi3)q% zym7&Lw38a|esA_Bb<|5DYw{IU+g>WY{THN|{BT+NC zY+}csxa{WG*B4R&ux-an`_b|vF4qX7sVplrY|@pW^-dzd|L!c10+mmI zpX*->RKNI%ly%n;c@&7NO!zc34Grr-yI}Ix&{k7Ou+m_TdwO~zrJU91AGbiPVD4RO z-q#R!%nc)+-P~ihAN#DVtSq)^ux!!>Jd()k^XJYzjE*L>4Cc7((ev51?Zj9)*1^(}o5uLUJw}7r6G3c9b&o%X#x3o<+6T4%ty&7^zp^c1y-#GoW^7pK(71(%+-sb>z82mQ? zEjMa3eUo{8C5C=;`Ssc4WQ?~X7myZra8ol4sA^p?J%p=wOOxckDp98V1C=FOWcYijlh3ZgN0%*u*y!-fs? z3=AY4VV9DCTUb~~xT~8wy%i?tq@Fq(8=Fa$+}#rSu0nUf$aI&Yd#?}3JLXY+KdRL^^yP~t+4mY6 zjYsZMdkwz>shQzCW{~kUHdc`nOXNR4c4M3p@@eu1s@I-D*B@COUU#*HAyH9>z8%EU zJ9qBfyFncV-v|Qx%gT;(?%g{vHML`!%FW$9{gr3>g%m|@c{@8n2LBh@Uhjf}wjq+l zcivNoljjQQ85;8U_NGJ@F`?SD2DlgyR>E^dnkQ8G^^wC&E*Nm)@D2fwXXeb7jaTqD z93^+tc&`goDg}HXp6zFo)LsCHIgb7afA)+8uksnqa7svoU`JK-@o*M z5~z&lrv}7?e^((%YaL5z?j0NJGVe2%m#08q(-dP%*bV|;2wppXUJ()(>|RoL0Ksc` zDhkN2`}@1`?uIQ4A`vIZoE&nBQfqm2SG}_Jl&WfhWz?qi8#g|B@>-jk3ylL`FukqBsNFM;mp5V(*Cu&o9?U*M9m`U42@Q8~F+CUxtc$D^Z7q zzgKQaS(i=D&N6035@_h+k~0^coh{bi-wy;$$HL-wA!X<8*tK`oDf~`=fZ&K^q*_uA zTU=w?hCl}!{00Kogm8BIAxRC8Jye@E(eCEB^|F0$jbV?g>gm!Sr;s=yH`2-)yTBZc zzd?3CiX?N;Ha}3!A{Itr8qgv_VEKh=;WEe`u@^V(vVrBwFRLQd%}vFu{;&$r=MT&A zPHwhzl+=%yeG^gAV+-yVfQGSUE>!!NSbHMzLZ2C~!gd~PNsqC3%@}zVJ3?#l@egW{ z^jN*!f8I0F)B8L~zdFeDA~A98JDp$Z35`De6`=DK2Yfj!4AF-J{(Ygjor$S3!?fbS zfddpAsNjH#;+riMM-ow4d`Z)n0+O`@n_lR~oYJfengL_GbDfN&^QelWqX@W^c>ZsI zzuKS_fS`f0fi^okJLAvxqU(QchJW9Z{%>!CbYQrwoQf8dDUO3*%ond+t?akYkmeK_ z&2V1G{m*71f3TZ6N%N3v)BW`&-5lZv(w-*^W~7=VS!YM|_L{6Lv=mLQu5_&)OKLI} z3a_lHic3t~2q5@UBjK3@&w8&Js;o)f1--fSD^d}gz1};ioc45(i~^GgDTza~9abK4 z@Y%Cx_bOQrO}lp97}%{AEnPJ?(Z>v01^MXs@#7C4J|rOlJA&|bCREH>I<>v!|7@Jb z6YQ6epiCGP$2to@|KwbBUQuJ&oi3t$ta|GCpbXF=cs;%5UvQ+yFhGkzKc=pGh9?vi zmVxWwej<$|ItKFbetz))a@4Hi53w;OP+9tX=TO~Fov&5bk1&sXLv~xU&o4*G^0(Jg zv$_zH0aWc~hP6#hO_^m}MJT{XDW$iZy?F7Yy}fVOu~B-78V{8ldQ8lev;V)!VoFCm-#%KeVU9OksYHN%wj zH(Y?4nJG|7^7^0Sh`B2(E6CS(!AyX&LY=i&M1(%1>Nr9MA`0LIqJs-cu*&M{zTgvV z#vtJUGlAWdF%x7SfB)XJWy{r>B`XVy*yqoGzH-mZ$zejEGdKTrWY%&|`-p4I8r3Lu zED*6n7Ki_Px!Mf`vvR1?#?RhkuK^2yT%*5JyEXYNccVPiKJAI@mDrB0trtH%*zI3b zB!}AW!@{2wR#~@;OLcL_yqP2&Eiz}?tVC6~cAy*-p%*za`n1*@KuRO|99gQ`+P;uV znh6^NKncktvzMsGs)*niqOsF{o_f(6L7aM>Ba2bOKMBR z(w{(cV*0MXl+#+|oeC~#1N4ap@f&<-Ray6ZzxAi&zJJ;j&(Y{XfD6MvjJKe4xKHGP z_>#t~t6FqBcAQdEBkv#@p=kOGRVBJZt|QSk6;<8N8GSGJ55%U^(~5v3=XfGIyiwl8 zHlF$mIz1E7?LzxbX673$O!K1*-2ROg^wn&K^z-O%5y1enu=&HskB~<1Ui;jW>goe1s~ChSbb8e!6LsyGI>BTAVG=UiguSj9cV-+BY$0qTkJ?gs02Cq4Zz zSPtZ>UI2A~s{_2e$I@GxFv=S27T5ydlQwQ;f#fD3E$snlQU67IP`l=t;mDa8R&7EQ zn2Z~!=LxBeWgM_u>ItwRP?R?$8F6t(Yx1mk-gnZqJi93t&upIAhS)_K(((zN=dN5C zJmlGh==i@sIR#6TuSLaft8kTCXP?#3xLs?Wpp|X__zXlI)IP$H@KVZolv-rvr5sAX zt%_k=5D1Y^sAy=g9PBXTH0-h}NWvGRx%l}{f7faBIUEqBmb6z$=n+c24+RC`GB-upJ(mPSxKu5VwEs@RK?dx5^#A2Z z*9XnI4I+}yiHZriNy7E-0p%Buz?d~w0k3iRQFHIxM-lkI3uPB*0)StKAFadx)&dNm z>)c;3+qrkoo}O2lR$_jjS5Vc1{^{xKD}`1=9elx?HY)%pR{wHNN#^rBJ5k$!rQZ`) zO|<7_4j3&uCZ?0w)*VmLV7DolW&o(wQlvC&`!sYV6D#G-n}?vj9vd9R-oMY@uk&n3%vDbJGl#uol*wav z;1-a-@<2Y!%m@F}lIA&CV}zL$GqWfRe*fbXk`|bZ*jNr^jvU89I&eygKu=ebv5(hN zQg%%`Rg{%sBXDe#-W2X6stLCc(6!cU;3D zENnv`R!urHV0yZ1#kzjt{7pB`-^8K434MM$d42r-OGlg>07zo5U3#;A&md}E)MlhsZ@BnvjS(93US3{$=V$Yv z2?0F>XpB{~1WADr8f5WyW+*{G1XGZ(z5zT85;|xXB<0Vb*bn(7uO$jZ>B-8ypgJ>A zdc|t2>xHBbIDwX(os|1?U@F@R-Sb;X1%k#0``fo~TMON#fJq-edGhGt!(uRyz#+_F zNY!QMk4r1PWbSHc;84odf=myeAzUwg7XcM?Lrt`duJos_E_T$j5G`Wirf$C0wF#RJ zNE+h|d_gGV;UBpn0A~|YcY?>1p!-AsnVg(NO$%ls<8nP4VE*3CM~VlWrg%-+CiJOdKhY0)I+`@@U4S4E^a#eei~HrX36I9M4V!ju$s)NB41a z`*&U83HZUlO5>fAq>vDD@^?y=S^TS4$~8~6r=7>BzA-L!!6-sa0paC`HrFqfS@Nd7 zdqS%AnNR%SPa8hlG%1f@4LKIhq;2Kcth+^FLUns&np3{M;|kR#lSb6!U}XvM8SVP# zAo@^;J;q3s1#-4464IQq2;0U-)tzgqWug%SaOt+E?0vN+?Nda4>-awCe_z*>U_b zIQrwM?*LZT*e=c4Ax{Ah&&|zYt`BU~BpL+l5-5ft2Eu<8sg9kPojn0SC>nuQ85lr> z!Q9+j3`|p@`%T%^B?o|ueRTzdCV=8L0{MsN6F_b3m;89|#=aVM#*}&URWX~p_BWyb)e*kra+5o6U zlZ}Um$7QmL8}#k-q@*bbc6~Ks#r$HIZNVU+JA}?7;0G8NRY<}K;gOb>rhV-ygl`v8 z`cu*iO^AM9j~|NBk~}Kn}O;8ANdS95r6_A6qN~j`1RPZ zibwwt1>^kWCF~Irb`iozse-vt3Sc@)WE#oIC@a1KMRTLRn?rWD;5nkU|G;*RZ1i)vCjwhJsx0#t$H8o~ckM`Gky-YuE z^5RN?eNRzC!x_?*May0wqIMq_7is)}A~-%=CM zyawPPts5O37VI-{2=@tvgZ)%;ml6$orQ~z*EkpBTt1C7S*kt@vxcYp#qhsH_dyKkN z6GF8y&()itqfqf6xWryaNk+>90dlw_KMZAsPgX|~MC|w|)CRv0G`M+qUX}a&z>{{m z^lkk!G=%8}dic`Rbn4WpE9MQ4tIY7e<7GS4AI8K4VW?8GJ2}C}LUd>%!V)fhLx=)n zMKS~kM&ud*8GbYfhyueRK*VFJR!qY2H!bBhwx3XARFsw`lm`U+`E6!%8$TMWAufh? z07AL0W@lw->0K}lWU+%TBNQUW5&+zO^oIl)4L3Jw#0r8L;uBsS#1=*SkT#*Q5_63` zqw-9I3W5&|bSt<IfsHN?S0C5gc*moU%Lu?aVSC=i3<8r_FYTeohFU-?`ayqD01(k>6t z*1v=G29!lAPSD7#GOmx`yb*{MtR6z8fL-|%#|`Kh7`V)x(NRNly&BDXwC)I{dXbL- z?^uqEJqLlo_}{<#_U^sAx-=kHtIQFN-h`tv)lvF}-+ViSPNJ1fU|;e%fa*$cWx0^e zNJc$?q&yJ7%eOy`L?iSSvH6F)ir5byK8$^H_1kM>gcVZI(Xjedne~Ep!TKSsL14tH z+FBwXFxmL(Ldqq?H*B#6kPfjKk4b$k?lvb7A$pBdP*B6sK;Jjk`^(=hxyv_>{0H3O z55DcC>0IlM&#mD-uKfPOC|UO46o1u4Z7J?9xl52Kw53=BXnMPLnK=E7 zl|iJ@U@YSWzk=BUa{(w@f`vu3WlI7ucDzR!LaTO$$p!>q6Z&*eC**~&B+%{qf^rum z*zHew-_aIWLVFdbcbNk$yDCJ~o}m^~Mo zV&?F^gFkXJ4~W}jx1*`VfprH>(k*0^rKKe(bO|AV&5G^-0+M`!FyNF5UqN4AUz3S8 z$DU(-K=6c&M)4f->>0+P@K5xsW=ky68wip&CdwJ1hyr%q!N&FgJRL$T0?=e9F7me( zho*-%Q`>cRlpOLvy%DaWq@ogV@yiL?RqI22aZvJ504#)rgfzX%an-eGN!uOcv2YNb zPie?U-(77{W6ee}jNVRbj@@2B0>(`0XZDDQmX?+|2-rS8KE)$WXOUt6eJ|q3JF5S9 znxmjwuusbC>w~~3Az2^AV1kk+-xMBRg`?aTbM5!X#l^LbI04}RSp!bKdl_LO6b@)! z8&EHH_sonXqAe5#fe>WTva$VjK`-P~8ugADu)mSVx67bRZ*}$}Bba3z# zn%iI^${_aH-{*DX6}a-+wNYx`AefQT&!N1xKh;pB)h>{Uq;yNBTE~5=p0#Q*4cfSVlW)|7uqA=F-?S_;}3(fk^jl=0?kat zh6H7#TlA*nP0ddu_n-#6mY4LuxD|hLro`F1?;SE5vOSg7_gR^ka3YF9#$=D*x)8YY z4p#}#I0-&Hp_|b>!RE)>NO<#R{K?V*(TG|SK&@H@K;hsKL0e_objeCO{+35$i?sPp zJYcvL@k~2T5F)Z6kv#|W!+s-L zUOZYl%OXDB_3!VakQcfZu%XBLWI-Ha3rtw{zQu%ts2&MQ@$y=K=gyt7nlLou&}>mm z_>45-8xZgl`4;E6j*8H1x&Id(-n9=OJyP)SD4feJHz|Q!6iUtsS*oc6cvv(9A(^xO z?!)sV2`dKN{=S)VBY}sZq2&wHL%MvbmjsxQ#Poq0;FOUQaL#W9V9cugfnIINGr%e zK%02n9<-V)vc8+UziqQZ_9cusHrO~?#(|xrKmi>Xzou)@5vz}rW&bP+MOD?n-||^H zsek`8`{PSBaIFe@Oc4@mZ8pYveSf&;7!c##i<6)6<`^i{QlJ{a&QsqI zhEGd+dinrPdT@w}KEV)Rc<(`A13ZPsAoP-7UtRDe#9}E^Vp|+vw z5fC`z_z`-G2iTc?{r%Ugf8ED{4{RyI@MN~tK&hYhETKW%6IqEF34c>_VJHfE~Z`pg5#1DX(wW3BLM-k zO7ISJqfNQbfoABlPo3QX>U2dVtaKVljnRb#BsZJqThH~^^w zJ3BjSq^kP*0cbmm&4l#NkJg!IYCxB5<1w`h+D<<(Ti6d^snOv)WogL+nkzo~p&tJL zPUb2lG`)W3Wcwor>2wq6Rxu;3GVM9hmeX&+Lp;N#;P=1E9Gli+HI_q}`i z&1!r86b`!6d9~PtZr9FvK)b(-K1z*SDKFbAw$v!Qy?h2PmuN|(|M5rj;@!J<&7xqf z-sy}S$s{6d01&hf5e7ZKTmsF@h3BCEWIBTXL;;6vzt8+Ye?%vhiZCFxbCBou48L7JX^-JVcU(8Vob>W?m zItvle2+uUdKtkTO3%q_FzM6`SwInLs3 z$e%sc*_r0d@ZYr(m9Jso#>&G{@0r$Hw+Lo0>NpGk`zt-_-UDIthJ`H=lZ>NXgS|NNAVD14>}e)w>% zd9MX8)*#B#jdy#Vgo#^j0fGYcf^Nkr4-Z+KGqilRdSSAi--KS|NnpU zq+W!SVV~lEWR?H-54{W!S$e!c9s>hLZY=l?>xz*lJW5cV zFXdrnf60$8yQX;Hl&e(qSoSVsnNwqkqW}8IW~^E>$qVSMaqgTKYCjx|K!(5S(KX^^ zINe1Wu8Ejt3@u_{`l2{GEf^Ru$~k&zLz16B2FApwdw8s*e?Li(sNEZ#ROdX2BG>PE z)?-hm-BX?LnD{(k9fTqyY>@X(x_k`S?%-7s&xSe~t-Wqu^St)C;9I9IDRurekC2qv zv`LKhOEj`%zf$K9JMH%EBHt6;){lm$_Y_%+u8v&(c=IvxBy^*=y~?mg?*pTSYmA3# z->7{qIbd5+L5Ujh*s;n}&HEbIabHJr*v!2&T1B#UTV6unyHfdpo3; zr*G`)T3XH*cr2((db@BX4|Jf4Tp>pm)0!l#iiS$WGh}LWycI;NIsO;F-*z&6z zH*7#eTAp(J&9$Aaj+~M}kr^p&cLzsM^_q)#GcNKGAwE(Ik?a9D35Kgw_NOjVAwELs zpb{e+h1wQ)<&?4U9?SrC6NVe#UJi|IoBQ009!E)#HTeGR3=E&}`&E^dIVtaWk=r0y zqX2WCaBjtZKT;0CN{2k0QVZKhoUj=@CU44dBXvL}6Wi~u)H>#Y6QSccbOe1^IcgL_ zj#*|cJb7PklF?9b)0t4AkVc&ZQ3%rSV}1SDlh*@@4wK5GGZNqb&VxNWW@Oa8+q!WC zr+`c|EN#a3Uqznw`J0x*UtZMI3vfnKdXqyBRwWG^>F14{;n?XhZRei#W6j21KITh0SY< zEwd)kO?!hwBjSyW^$A+u!SHQ{LGY8W0-4!-|n>IMz`^jP%gr z^4ZLNsoSXOUcY`_UJo(JnFI9Mxr$`-iPd*p)ep|$U%yTV1#pnuia6jY5%xDU$YYbC zWUKY~@!d)k0x)v&o4_`TqU8~rwl~Jdgef79$62d?_=udG9P?;T05o!%2dl&LOR_CD z)-Ym?z|9#eh*7<FnxaDDy(sy8(>w9xE4ntFj8eSnA<@vbt;nx$dx zQ5zaKoo-XAD07~I(ZQJr=NG@{Ya!eKPk{4Z+BGNb^E)76gZ>fEvK4AEt?3Ty>>ayy zRgd2#_$CeiCDms+-YFw43i9;Lr zx9&o;`S3!nfq-W8Ly#^IbGw%!JS&$kl1sIa1Rf733B)ZYu0&W zTL2nwg>$`Ikdo2BH^J>FsFynj%{wk2aG|v6!!fS}5SiCQDtG-`3qUTZfbt`&Z4Hr| z<1phXWo5d745Q!T&ZDfjyCr`hAQtq5G>)MLupCj=6TTDo8wCjIhP&Ljq3;}nkatQ& z#SCEvs}C}>R%l9H3s&y8|FRW_WU3E`Joi8gLC%ii?-;(`folwLy6LH)=|*xC4MiO& z6wC&jBJZ%V;Bf)vogp1TsVxzVi6gHzDG`n>XV#uks~JS@UmD!Uwd_~Lx!Qk zg+0Co%n*d9erj{KZrx&NL9rV5;e+0iGlm{2vUMy`J22~qMthVJG^k(x#z76xDHJ&G z1Ym35{LJwmb_!Rxl@DQCB66&M$9nN#qbM}~sBz37$3PpK^llK=t$(-S{<3oUTHR$g z6tqMUh{fmoK%htLGfp0oLyyQ1vp@51d3lwk@2Xm3zYV1awC+8GTv3->7O~9I;k%6U z4Mc5%!MgOizk2!72>q^~{r$c-XFI%%ilfm=LKelm$d}AV{O;D1) zMFA=*CFkX{h%TW3I{FhfcLZgnr6WlcY`J#w!!BCtSR82KMCBH$ynmwx9*Brc$qCV1 zPniLL404nVnn2P&fCg~WrcD-SiPi~8R~i>>koG?)SBzM&q1Br^&6YTla2xXmiTr7M zdm+Hh+L@L_bZ3!`@e1fQ;+SGao7L>kn*pUgu{K?WXr^Gz9fx>bpYa;6QI^i&_blFH z_ojZo4QAo)AV%KUNs3AzQdUbFR@SROpG_-9(Xd|Hm zg?1QL9J*NOi1eUQXftEYp?4KjJwzK3NAGToYZ^q?iev_IT^6pVA+vvl&&-#cb;V^c z0bn?Nm|IavA{34HhZ#PKj0^zloivk>Pdb3${y#!QBo4+6SH>jarSrPR5Z%$dBg_yQ zy@)b4$x#>%KL||OTz|)rL}66oz+aY!)o#=j030M|8Ih@c!G(o!7Ml9?@?49k0xc8}Id0Nas*D4W2c ze|;r5pM?Gg=^8zH#3`G3O>`F+c1TlqcSkop#T^1B_aAmxp8#&pY-2(thep)8oISy& zSR>$)2)6{`*4MA+NCOhIZ$x{m){cAS)zC(ZHL~@7x*u=#hPW=)XlG;3P-D#+^qs_Q>2_{G{~{8P_-(wm&rxB7zi|I(5<43 zZJxc0QkZ0Sbgsdr*c8tHN*fLa_42Kax;iGEzRLC(TMTY3<2&?IB&p9>hiUdo$Yjov4G ziYr8Pu;8Hr-IrUm0{hnTaOl6<)SJn`=X0R-!Vh~rGcECy(LZl5OYrlb4t~q$_xQ~7 z<6d3{3REC*LQbi3g$&lS&v(pn?Tbx`9Eg;WOKYiS6t8is)=yh}6|)UZ4zlkN@AUNb zZ{Q(bcv=fEk-K+rP>>5j5o+NTHKmb zhpWi!939KRheHh7kZvdkXuS?(<@FY+wAEFv>;=iy*URJ*9%x!FzA;*Z(<$UHwzb`8 ziFH8fL^`i2TIo1Z4?KipuLIECKdyIEAFYzf3o8LO{TFsXIxcn(_i4W`FgJ|)A>&G+26 zaIo>j32#=$ZQFtj>QY^}L#ozu#H23HmO0}d%NKU8LsEXu+fQxb=1wwjXmFR4kd%Z# z?aP)hEm|w6X$(C?uYTiWepwLmvC7ia?wy&@N%Y= zT=TGb_+xPPMLzVjZ`4brcHrVs!o#C!jmtx^*BZWmUr3mfgJzMm`T(ORGf`55B)&D~ z{>#q+eX+iuKfN@QhtF0xF-FvKRp`A=7;{}MoVjUu^~4L!jH@mDI9tXc*ADTb%B>|~ zEP|&19R$NeEQl|FT4vGh|Iu_NU_Gbn8@H3C?7LJVYeTk_s3c1!%aE-?O4g)Rg`$up zrI3(pX$*=&g;Yc;%ATbZMH_`mk@WvOGv|M{l4$B+|PaA&)CZ??;de> zak=$5+G&7o}oaV_P<)$D`b?9kbNlE$Nn2$gcUPWOuc<7l|YletbCha8-~aGK`v z3YUt(k!~8jKA7o#av2l8AmcwtA{0GTAULs^Oc=kc!MX3mWs@3jVoWdpG$PJcPrts) z);sc)Z@Guz-uT)FK+#$C1Me9f8_ksM;PmQ*??L$qHyVD|ZE$rhLK-V3mi@K091a*Gvk5ZVOJ&& zT&2}ldvb;+WWTsrlNeO5de5j<%T*`r=j&dlDye+ao$-=nj{e)=?ohr0=>G9d-=lOIt4nUm^)$$SAo5ay8&C^{;pnLEGUN z-um^;1Y~^Wx^*-++bNYg_UZGG1{0zmEqvSUhB70UhKouO_fX^<88BCPjKXP#AvqF1 zT#_r(Y7Og%R1As?0i7ux#Fv%<1o2@869P_hKW1wf5!+W4&#rQ>_gB2;S~KzPIMubW zw?55UmdFL3o;$R)QeE_qpFiteKf7M;o3hATGhuM1m+q@q6V3Sg&+>>Io6B9lD)T6? zQ31T;(tL~EHsoC0xR>?4yKZQFIqPm+$czEw1{h|GkArR4X0x|-;3w+3x^HALc? z!sEXZ4|pg9cp6YMz~>@WK8CNaqa%*Mcz4h}f1A~H_=wwXkW7*vKkfi(y!c9&)_2z_ zMEZ1SZZRwlGwwXy;n|NL3&l>9AS4>gW(Ul70lf;b1c%ih&Bv0(k8MT@$Rwd)P@iB}Y|v2=*< zklxJ!xGB=1AW{x*V4D5O#UsiqB;Wb#_wNH#RU0$6p@-$sGq{|-cfp2iyyy&&w!)5a1FqCDK1y7U0Vni0W;N8BP|3l=k#fB zcoR)fzur|2Sd*Kmp`wAm4wYiEwH6?b0dztHGFmDs!z-FLX=3l-@Sbo(ufPOtI&@PS zLb1@Yxc+R)(HJ`8kSg@E?t$G=}aUKkkT#&~Q(%Y+>2C_4dN4tH_Bh zE|T(jVpp(#e*1PG>FihpSR8lsE7NGaL>WdsdHM3?5PjwsDm8Y0{`qj?_RF5J**)2h7$GG&xiMvmhezEy&mkY#SbhKl6{@Qd zJ^gZSx&uj+$N|IAe@fByP)deYXWHgnF0hi`|) z6yj9A#m*@;ZlAxuhMHOvTU*;BN2Q+k@wwHhh-GQ_a&tqWPVg?^i{C+M^B?yYOcVlu z4*>#G0Q)JCURcM$=xe2q$2Pde^&!u18uo3>>}KT?Xt0peZzS~$PK!n(sYYXsHvfAj=QDD#*I_s9ze|R z+gZE4!7jSGCiXyn6W$G@uICnjDq(WsWp@59M-*RK3fJ{}E^BUCqZQ)2cdz>J;q9q0 zXfv=mC?1(M*#1?&e*I_%lAk|6$QPuYDlRTo88~oiX4u0Oz!CRH8chA`uTBI>NUk0G z_Vqt{v=>uOc|VUX(|~eM={F&DCgA{Sc+8nw6B^zr?BLnv&dH!@RAaSj7_K+#iX=RJx|XV0ceir-Ms0!K|akOdVL z5RD4~IC+$52}6z#)7I_|D}(OYU)^`3B?%%YCkM!pvq~QpuI%LEa%ih6I(Gmu>h*6w z-4kx5_E7uI1PKsf7b+BRBUDbX%HyYKO-<4x`lc%`E&k{BZ(2<%AX_de?7V52{{X5? zC^Cu47K8&h#)v`_Y79;7^w+z_9Bt?7>WT%5`si!_+U@`wNP{|HfU!ekI1F7+KOPuV z!=c62_FTc;@w-MwojyG&Qg_OfDU5JQokH%-%rqqi^E)GAVnk-lD{y})KbpV>7{E(A zP+7>J7FwA4UtPDG8j=ACyusgNwDTUW5YN$X7I!!5QVAk%ht-JIL$rQ@o{=aBVL!D{ zQr>HLG9MVJ(G$CCZyJ*C$nYIt+}_zaoUajhvF^u@+zk_1Cn(er`tC+xeSb7^e|eh3 zLzPk!a?=H)$T72OYUZ4qTl@1TR#^*L+EEVE2*_T!PoHUyDYM zUJwu-$g`r>bXy@Su=m0i%je5b_0Z=c;HduAm{;fCe;s)w&aH9)n2FKk$smCS*sXJf z!1C`K`CsaFV@Z}>g;p8hTKFq-;=qXmB?;IFO&ciP%g% z;HWGm|2XJ1U?`2vi59XC$U9RoSV9i{YqmoB%imL~$zC6BLO`t1cD&moA0n2mZG_L) zK9ofe2e_m(@7tfRmN`>!=)mvB5UZg5fp`JTh9?ew-^?nv)x-<9lAr`3!vRDi05DlB zlc7k%N@`|FOZYE(as6olZ{IcqQQxGoKkh3;52}T+W5=HRt5YYjUEHq|>v22l&6{4R zykSw3*oFOo2ZCwXgkf264DWgBaW0>a^j0FwmYn8F*Jn7Zy$$kkm{~q;$-f<=`l=C?W_+fiuHK9hLZ5o*h2LTm>;KdlVQ4f~RzW0VEI6!Pt{9{;H2S2g(0=ja#UcZz zkK}=g6c2Vv#mh8g@uKn7&}iAMTQ|OhFcQ9o2K2l=t-7YBh%v^eH?d0aD=hOm-FwF8 z6JI=r>^xjK(@^O^(S>-dWxwL`0v<;ZBe(Y~eNfZVn)CkJix)41i;h8XCbc9qSmJE7 z{X0WsrPNQYAofdh&S39f>)B*)#MRoM&b$HP&ZEbs%?GfAsJam=7Ff|xdK(dy$X^h_ zKa%7wSllpcq3|=1wMbG%6Axyf$bz?jhJn(QF2`!?;erpyOS^NYDXD?tj`(KFBnY_X zG>ovGeo&#sTeM^Q_DL06YIMhP ziE~B^=@}BXg~0p}-%3m^(@GyxUErs*4fY+=1a+Row=z^A!AOwO^s{-nAzC)AsYJCx z;7+D;c>s=Q3ZO{gD2s(CaS!oKSleI`X_=5t(J2m$)KD^Lp$6}Qzaj0a5e)#y-8c;S zbbwDO24i@%929JLhV!RUbyIuZ12X;>WL>_zz~arF-0Ua>p7}{zMIuXt1Cd++wG#4R z@~c-NglnKiv2^FGlWcqU?!9vD+B3*B*$8wDepXuggxOz5GmletE9@W@7uDj*QH-pH zE<1aD4_sHw_rx4K;kpt(nL(z~8P$g6d1$x4-RrepnEnf=4g6>K*!6E8pq<&{?>{qa zUKHm^6uH8S>d-+odkMN6BnThHbkzT2b^B={5HyApy$OeBV(5P5*l|jm0}JYt#M58H z`xlWPtg!_#EBr28>SVMt@pp9;c(qJSv8L(B>(bHD>5h!qf5Oj!VrB^?2ci^r80AHV z#d)tzYpf93hzTGdF;*%=I;jBQi`op0(d^r!eJe#iLN7(ml+@12)cTT$iJnBc^KJTi zd>SExXRH+3{yaKk0D&%^h%cvw$VUyuOz94W-7Pt~9-oIit{-8Q(h3k?G;oLqCbr&P zk_v5~nB-6WDY}SGoja>{isW3$><3#s2!9`5cUto^O(V=2tyEP--jWu&y=YJfx#9~7 z0S$Ky!hxt0AgeGqz3&JvzY7E_n!is&xMl8>tA8+q7xO`2ZUK*p0utfom!&w>^T{n8SI54gC6RYB1gi#)hFkd-JkkNqb};oncNDuW0VGg| zE+apaN7IvD^Os-YBgKB0L(k6%_KY3J3O+g=+N)LHaB2fzlz2k0j=O%{>hc1aUV0*u z-azsu1LEE|(z$dl$4tBoPR?Rofj9(-=M%%&*mkDAZoPC}hAz5zh$44i?{R9KA|5&0 zC7@MA0T~l)F`(lM^dnZF%vy~WAs=zf(o4mrI?la%upAFcupJlTC?pCcW50!boo7D~m@jU*6{nOo#V8RGIRJRtFp+1$#HKw-t{T z-kO9poFMP#TTocbS`LP4?^)3!DpN!hC^p^L9YW~0%{ZbC7(W>VFYq# zr&1+ifHavbBIr5)>0bM-H$)I9++K^Vr9c-lV;7R>z|KI=CVzL?FsyjGQzU!j=&0`} zL}Zusz&n(saDl(n0LIBY!cC_DJxmlOO!37{k7*~fs79j*6Vm-~{3Sa+c#y&s=ReF$ z<&8_EBw)l1OYwiW;?Bc|e)yI#Ko{YEA{PJd`U>1yI*(ckkzJ4kH}!qz0e}&WULSm9 zNX;_PD)Dj^1}?2Hv7FyDjS`$*F_jRAnJ>{+Y(uO&8ksxdQT;zd(e!<)`vY>}?|OXy zRPxo7{(@fUn8WW zVk4xq@$>T&o3me^*Lwz?SZg@u|G5CWH0Ws=y?zg_y}(_Rv(9uUpsxIeKL-jZRXP~) zsiXk-VPJ8TPsQ#>abfG*P)j-)cbj9hzHT!In+_ggGZ_sV#jTXLjtjWRQWKC0XPFj4V4smg_rS;8t8-qi!`0?w!o?-11-+f z88f8+q@w$5le`9_kp7)Nqb8N2cmcX3G{G{grQ;Ag$Mfe45HCwQXF~k;YB}auPG%Pt z7|3zju%YC#neY%1wL&13ZAcQ%qwmbTXjbs)(*veU><>Jk^NyxR{GNHPzdz^w_M^Jy z9vCod$KCdS2CsxKxg|tMg0v88SmjjSU-!TH#hMS6O$wLH;u{lhKR`}P6iM&MeSZG z(f3fmLb|4vqmE2Se1mTr5c)pqaY%05cr5uD{zF5CQ!u2@UWz@Lh{;5Y*Fx0~1e zYhujCgo&6cUS!0ir@*i%tLl1B;V>sJU5ZO93K>vxs2yeVnh!fx#$JaID)l2rjP*#TJ?YC+*~!+Qdl`1IAQ%>Q=&N8pD2_7`k6KmZ0-XK;9A zXUo3rZMt`mV)i36)TH5pY-OONVJDQ>*ze^r zF<%8YY7D92d~qp7WZdI-KVj8U;~6@ zxB6uTJX*$c1adPsH&@?zFV`ozquPmPqkFc)>4xP-RbKtJMker#A4qqCTR2cIzXM6} z_C*&-+SANFB;p33=d01s1BEQj4=M5?=C%5OegM1Rp_C!~ssR%w9LG@umIMYWe>)AX zqYWjsG7t|M-h??f96%m{IDud(ivnxiy2%&Ko^n#f!kgH`BSm7iTq{I4Zg4YP{eC?{ zjS>`ehb~G6&0ut;rK<_Z0AUgfcmWWB7hYX`@Pf9#pPvnVun6t6@?SrBvJG=rlAY)V z%F8nuh66*uF9M^QIeog=L-XjNp}W!F!-dl4#a`7VrDM#Q0e8$@J$7=^><5mN7*r~y z28(_+%rQDxC#iyR3&t)V6#Efxe>TEwfe8s_R|;MHxyCE?4^+A{((;`jU~nIse@18TefdXbj43n;Wxw{BAa zRRCitJo#v?6cqNSt?JO51w*z>e3^H5pxY*e@>0b3l6xi~0=w=!e^-WuU+iBMpY%jP ztZH|l2keB-0*vx06!xoErDI37k@yd9?ZATaef##&M?Ga!OwTEC8>Ib^g<>u+ABL={sXIZhdd%oLiz9du*pykZ+_Kwf&Zx5wt?36Bj?N7=3w5Ah z+cP!K@Al8{H#zoP(x2Tb%z-{x&Tv;`&O9_yIAGH|AV=I>D8hn@JeDs%ZGO7R7!kO+ z9A}~f4a9%~hCfMw%NcWtg@ob`Kq(Dli_SpZuHj;xLFYqg%$|VV^9?}-Zu2S6o*jU_ zCqPJG(cViJEzHfCG#q~TQ8^p_h6b|*0?&yd+T{ciUL~3liMd_7?y!uS67lpMX2bNT zs9enB9k%mArF6~uxOmS$4TWYMEDoP4n19jCWaXsd(&07CJs1s;UU*Zm zlme%ya);k$;4KUv!=6#DB?cnszW98BX;vFvKnc8G-?>*W6AnA`)v}oj zR8Q8|PH5ivwsB{RS*J!W>^xW=b+>M0rHHESOaMUSnAX*WGIK@gz3){Ru(N=w0Mu{DP<#OiUGOijfA1fd#8Mp0Z5yZct%4@ z*jm^d->y5HK+Xs-m5Od)1510R9i1^g?C8;>n~FpkPIK^%E5v*MWYlR_v8Ne^p<><% zw{|HBP+kCUUiWwdeXCqWXr4F@OAGYrbtj#*&;FjR9~`-6)v9F5QEtX>l<66lT^$?( z8x)7r3xUNJ#i7B`6hAUp#yd$6dcCQM0LP{+LD=`(>ZXU?nIQP7LVND|GM&OEp&`&xt%El z61UWgu?s{FSXB6h|B)35LyO4^W$f`8XBxV5^HvZ~aCN-z+R^qJwG6s@f>slYHPezZ zoF~Sn#6Qt%q5gXqEqz>^HWMJeCJ5o|0p3;a3FCvs-H>_hD+cvG%WukDGAF2Eox{wS z`Zl|~mM*;EGK?M|`R?8AQr)9QMXrznqMdts&INW`^4h_hB^gRXVm&XJ(X$)7wop;W ztw~cV*4k^i@m=#(m&)-2VHe&k3yVKd^$OJ)fw16(ge>KQ`ZN!q)ynNhLQK)Ib`1p6 zp+_TQ-%C#ZmzEhE?3%jUd|7w+GjLX9DF4o+4`BK!Vaa-yo(+C$yHtHbC zBmgjgKG?WR@wQ(-X*V)xB133hNQaw%62!^{EG(l`L$cYLYf~H4<1{Hfr?77}3zcDVV$@8(29wd{8kIS+L+`(cvsGH>xsFU{McKj@&_BZ`RS) z#zuJKkZwNYv`SS+!pPr<#@>*V0DDGwf^=x$XbDqSo@gu1!9*;Anj+o zcPZ<661GhOyt=SBb{{4havW(4-oAaS6&DMgie$V8CsQI9gAb}~*S&JguM^%&dB zALW$G^Py^SU>SmJDG^|n2`;M)w7iqdZGN%cX!tTcC@B&!Lg^RN25C1@VeSW#vbArfVh&5KOM|VJrz;5Y>WkW z!)fdXpGi1|E7z=94((YiJ82R}lH6rW1X)H_vC%n-(pE?3Wr9cSsB26cV8&ZUdHH)i zTNG`*%cq%%iAlly#GlVC`d{f*d&@8t|8Us%H1B|CMb~Pm{~QuwPj4uFCca2nunth@ z%1Sam+m4;e?!NFnzt9Y9v7PefbGhVU*#pCSrH2qMDWnA|EU9APAq=2`#0dBj1Q~^c zYTFz)`vpJh*F@2H(-=ZB<18PFQ1Cm5_D2b6LO15*xN9$&I5%Z8$jVGeqIB{CSN7I% zk`kJZCS;=L1fIC4S@Kwv*)eJpojj;UZ!lMa@(%zQ>z409f+tZpQx&}fYoNKYFiVF` z^Sg3lL#LK4TPkDlP0J!{-r;z)<3{pDC6)6Sap+I2<-PO>>_F?1G8ECmC-2ljFoP z$Yc1>v(8IJ_dQ*Kth@oh!lOYJo?J$f%Y0Cb;&9AUV|sy?4-QhU+1K@Dumzw7U4n_U zfGU7A?tA0^{+=Li3;Xt+MRmb9ae!Qwlbh>C`|4H_9xQ%M{QQrUxq8}zpgGaP3cnVT z6OQRq|NQdmu>yj4xhs{hg4F~FSRY|Y(hOo#Ge6RN|MN`A+RS@t2VqPlNiFr|%Ti!n z*;FL6Y~iY~J<&3`JN*+d^ZLa4F@ztunH#_*mCdyu-L3e_AjOX#UOjL+Boc9EI~1xC z3q_?NcisN-Cl+B&QQNv1%msZANSEX-`!@L5WMFZl6K?RvFDgv;1W{lq=Wdc1W+$yw z-g$FfMLeQJD*I5`0xaGB7~ITd*Vn&8xPl zby*H1XVZU+j7t=I7mp4jk_l-_NIv{<@s^ZfCQT<5s#rg;fDp@Fh!BW67wt>bZu;}d ztK6bb)YC7&2KO1k2($Wn(Eo0^YvA_G!j z%R?bgKcNfcCy4(FXHZTUzot5*5j%i17&+3>!oiWZ)6?&4>Lz^AW!nKD&EV?2j4nue zVin%Ar~8|FwYZi7{o+8xDTAShY^ysw>z(kExUPc9Mz3*$fkSEi6t)G!QwU7W`G8{wc8mUMmz`+bJC@oFrVcf-r1OtEpRC|J~ab3me8LE;j!Q$}%oyS1D zgXS$eQ+(OVMBxOB`3r+Xv8xAZkcCNT!nDVYLz1nJPdYkq1bX@ozD&#(rZh#xvJmwg z)6Ub3vVtw$cXjAzymZXbAHOQIWNr&2M}r|t$|%>QfCIEUNl6PBDM%3*?QI;` z%WMR-Gw_Sa?AaZRPj^5Bz)rFL)^BURdT-xc%(o&|Xwx|RxYk8Fnx33BwYx>*J_Ely zgQN&xOiV{;$-1B!oMZoHS0ByeR@2Qv6_Bk7q#Jl|g5keA#&=U6Wli1%ZUB*%&3W+n z^O~MD0RW?e=oG)LL5o=oVza|c3dyWgbQ-$4N5Qy44$m8WP_I@g(WKeT!wkAaz|(tN zN2awvG--r6lJF^dF{1WPNpG0`h{q3DaN*nbH@hLF6K@<&dOk=Swex*IHbhWXICRL} z=s4H~kQLI#Q%)KYS!@|-8D(J=#&9F-^Yy{c)!%*av1=@G4}5bIo8QDZ@4$hS<^Fp2 zyc#!^O<3rO!{g#=r~I%`n_W^^m;&x2(Vk?$JNm-VKr){+3=%JwE>)2&1QRFr#dJ$v z$@^Uz&_TpnmOkEOGuA`a+veSz38z5tSULG2BU{pla3-J~tMaHQ7cz|I;KQ%9B|;|A zX?j1xhSqdcX)sh9tbeQi`{I&fDi7siEeC$#p!1na5V~J<8#Yv9tXVv=k!2w-P8~fo7Gw=4GR`RYu zLxjX8HlZ>|XD1dHJ3k z7Rm!5Wsuu}Q7A`S$$FuU92PLV6xa#g4oo@puIodxuyMAXG|4+9gs>?k5G^OzO}{sf z;S_DU#T8CBJw^xkr)4F6jw!ffNo`2LBq>rfarR_%@Jh84x~Cy3{V)P2fO7+)x7Be> z8VPHVHXW}S#-v)%*cD^om6~y&@qnR;94Js4VWp6I^dj|G=LeTgLK~%Fl#Fi~Ps;{9yzk^{pLRDd zz9kkC5%-|ywp6aq?YQ;gl0}Qy9&+#L)0mcKKQ5~MQv=4<3ZF6Iu~4#!Gi(w71(k=a z52C+BYJ4#Bx6*zJg2134o5`;UsUn9XClW?Tw-U!l^Br5f-&yL$>t+B0O-TjOpDgNR zZ2AtYjj0_T95=t->oxIunX$eibJo5*4v7+ghF-wgTzF1m7@dXKd(xwOY3BCn+&P8n zObEl=(3?z2AQ-@2yeG~5ph3Ljc7e4E)#Y`vD#HpD8}`l8>vDb>@e1jib2>cRvjf=i883828yY z;%Q@xl#h@zSXMZF!0U=uh$45v#&iEgWC&{TNSIRdoq12to-=#BWozc ztNcSx3d!=z8k5yLn2}NVh!ZFFh#t-tXk2g_pXuy;_2zo^M$>+|=9+!dlx3H=M(>|> zi1C{Yo|(X7XJ!i0iZUIHS=)> zA>c68;cb5}HT!i^w*Lv-N{WD;#hXQUdKU@c6fPBhSbs0t0 zlPMjnQ?1(4Er8~x*j2nYqrEFj{_9CFxta_Al1VrkI1NgBI1_0;<&4He%1a78alu-# zO`9O|P3T+$Qx&eS{f8GSpf3(rUT;H-tk%m8wdbo}!4Kg4B(y$ZlvD?F@?`zWRjY(S zkPuiJ8zi(Q41P`OiiCH-kD`b&9-xfgnnn^0gURpn44b~a}-HXtnxit<^`rZDx&*(rnIDbb)1@|_A#mksZ zJXl2)$q0bQ9XJy_^DCYMIBuxy8G7CSa`Ef8_$Ndqd_Pi{Av;NV+8^>5#0i@U#Pa9B zQANw#-}7)T{KMGdqv{INK`RSzI>)`1y#BA`$@&F$?K}98NLDL z`%u8!<-{eE3J}i?CE9{}!ldm&6}=$ZS^=O|I$YTjO<|ih0>;ukoX*4;j}I8=J^4bt zu3wn%q*iwRaD>Pt+ykft(|(>M*i7b0@OxyI%GO_vY#Uc8C&i$V(F;hNMB1tykfgr)#&E7@IH zS##aCz6Gud_0e?rC9y{yt{AiKwR#72a5p+9Z7ws8#uPcQhX-_q_=4qI#F;Zj1C}40 zZWE;$$}B?`D2$+9!8=zqs0K2~Dd^!8d+!n4FvQ?nO{@+5cN$t#=oy`3;`OV|ZUj+I zf$N1-R$#GqZ7lU3sCJa`u(fFQlJ}Xys-L}l*XkVxb+Y?9A>Xl?YTri%uN%6k zEy_4J{X>n~xKG8!C#b3P=WD!RT-ASCB)eJ3Bix-`r|T`Qm$+b(BwH+Ccfm2B&G|Pi zJhWT0rU3e&bSV6Cp&rsiZF!Uc4BTJU0s9OM(_;ZXr%9;4z~hS$BmaJ^&>s65`e??} zl9`ZeH7dD(XhN{vvqFR--*(me8fot^O|B{kgQbtF}@enXqry=^R4Lqpd&v7 zuKlo3I@+yGTgp5D;sLyJj0ZD>gBMTlN_ahEfP~zh`lg zh;cUCh726|_mrjjCyUhrCm@i_j#?7@bLz=iaxYJ`u{3o0OX1cfAlk;lKos7>(#8Dz z!@4~6g{o!6#V@@tb;~CS1itzBBF1=bI(9tMtttT0ZnnI#xA?F#G+z}J{f<6u1j7|4 zX;KN@=Bk*gdByMe>yx5dYX>y6?(oG!=AK#o?iE5m5F>D?!*mGY`(%`eSW68l2)O@O z!V)_4?*A>3qOZ2cj6WpKN3&kF$sH^%SKsjHaxDOok5F1qm_)JJ@po3kkBM4dk=bP* zeu&#B(>77XRNd7}xhiY%Xx}tzX4QBmR^O634IKd85bIL*K3@6hsadze*!LmLvhGrG zf>!BizudC6$(&;uQjv3^9C8_#$UOMf+HwO;&8mHuobQgN`$J&wS~dH6)YXcD)4qCY zO`Mywj)s>*hHA9%?$Gic|L#A#yLJrx*kp3|ht0%bOxw@}d{)|X_sOLBC>8|EWJu6| z+Gc2jO27Tfv$Auy@EP*j%rQpC0<(hBtpAnZ{7u>rf?LmL^liq<_Sx)I!8FX}Pyib`qE|fZNd@gO-jxU}EANRh2P~Sq9_bX6mLe?FTbtaJ zEPhYIjgp0{setV)Crnv7{7ccqh7eYrT$ouuR3{qwoBD3%ohvAJgr~|KSmNO7YwPIf zxE!f4Bu)HZrD=xCR!DvN4jLG`gbpN6h@(KLvhaX13VDrsy)lwv*con6&0yY z$i~P@0>7{4i;CzBbPL3?fQ}K9C-I1a6Gc4YBm38DzUb=Ie=ik=d;b{7F#`t=F5NJ} z;^xK~?ebfS_rwyS&HdZhC*xrfbzgqh!$! zQ7*G^!Py#zXa)~j(>4oef#`bizoexF*Zalw>a9-Qt-hWW+UjQsOzI_xnthaJkH~)c($_L7xciCRcBV0}a@zbn z^~aqDcCi_&M|i1D+VO?xnvvI9T zfL8Jy=}vE3*W8L>|7hurAAcCe3uB1=d7TQ{SZwv2y%p;!C8uSVTOZ%U!FT-pdi1VA zLtcaxfAjL3>kxkC%q1%1AEs8vQ>rIdf6x8~E4ZcmzIpV~gAEMD0YRcgb zg;RNzg|9n04!L)xE=(Q!f_+H82M#*1c#GP7)h69@PcHD85ZCL=u5G$W&!0nS z|D)-AL!+e0&GdvE(0SGRKabza0$v#_uwVcmaj*4r8ICLy@^iHWIci-5*o z8Zi_pWcWp^0{2Y{RPBJB|AQ@E}slm{&uC-fZ6W-&TdTy~gZin#oY{K7oh6fOR>C02?LL z(og2Kd5tN+^tE5gZL2Sq&Az$#za=?KS*jp+>%)is&c|b7Dr>gw>|~%g`_hK8ji^kf zTe*{O1Q(|Xaz8fOs(kYBv$|&!YU{*f-qBG?0cqk(JE$~F__ic`{lgV8pW1h=L6?g- z^zrN0yLp1t0)g|5I&YQ5^pt7*P1#cb^h@Fu-5$MBVz+Z&Kfg`m^UW!ht}83-k4_$5 zpPDtd?s9FaZ_Z%c1H{fA(>LAHrI0*JqP}cq zJybFuwhc8-ub#>^$MZ;b!I&YSz#}F#AwItD zZC>ABnO_IQE3{g)W^ zENs-FF2Bv?YYL-x2r@71jjtPpEiSHG-jU7U|AGltleW2mep;4l0no_W1L+}du$X-Z@l)e`ZmYLVBOlaDOQ*7p9&1^&@%~+ z4fd7OUU`)Z7QnZj(|Q#OBir$fU}%@e6(}yual*!>k)H9v*lX7uUfBdsE9o@5_REaz z`P;g@YLm0Fe_}H}l%5#W z^FEej3Q52M=!q2F7j#RkB!rKoXz9`4oQ$3lu&I!v1XbZJLQi3f(3^!7>vTvNHd{o zjv%~@s~CpE+1U?sQXs04{k`>~^sCY``L-M0tcw-cnTm4j(;dg?r3S5k_1Eyp`J=b@ zHoa6cV4Jbg#0tBMaKb;5Ccz3_TA3kWDumatf*x>Vf+Yw-vo8r<$9!J3uYFeEjT3L5 z=({!yrrurQ(=o4MweugKk`jGrYTj4=U|e|OfWx@HO;Z1YPspsHPx++11w<<4J`bSB z(+v_BC*0h>7Bh8OEWq0$tnUc_2J68rkN$Q&w5wweq-vU-Sv(qHYp#Q;0;$? zJbnJ6>9Gm6kR4l~AFAAOj_*kFHLpo@8{thy*{^>Ewqk=qGQW|s|4z=Ors6q(Vx1G>H4wF3^M|j5? z=?Tdp|4NKHn|FBa!Vnb+9L`|?7osP@ZdCRw3eiXyolMGQH-)Tk%uLr_e|Z6nm9em= zzf~rGx7(7~H#p;Ycvzcle(f&X`RcixxMTS1mDOdZ9u;bb-@(iuHD-)Uj||(zdFn4m z%53T9iLDEFr~BtI*0dSC4>G^2?)n0fr# zR5Z2QvPO*Idqez}X@R`cH7}jx%>jBw;A7K>S!62T;Mwv%N>SN+V5;!o8UI5*zXzrM zZGr0ZiW)i+(MN$bVJ~Ng$&+%-#LVW)+7d!!6%Y#H_z^7FhDc0kH=qXLcp8kYwly@m z)avw09^ZZEY>VvAMQk?(((R*Y!UGW854seSVdl?@gPTG~O#wh1)^0giW>2A(=xjR6 z@lL&y@v7s3oRL~u(Rc2>E~|!m>}hPrS zZ{6T*Oo5jpY<$*s2%&mZ&kZ`EdwY%D2R0%;CQQ#KgJ15<6VFYO=c-^|rxV_V`?QPk zOJ!V=HQl9jrw@{c&8HCX!JvNn^5sB34IYgUoq$T6D&u?ZX1nIzfX<_*_&k34&OYGy zam|=J(BiG<-WopGX-h1B7)HKpd|&pnz3D$;f~E15HX2>byAPU7qy|J++&QUhP1}^R zYS^NpG%Bi;+RN>wvvzMG53cRM!SnXY+}QO?7FPY*)M0ipHw7;TarWe>I+)$fx~nsK z`MtK|mHjW6*x%Z*a@oL6^u_K6_U~6+vv${o^Wb0jFOt!5>2CviXVX5xGU~wG)ljnG zK>DZNp)9a&Q3zxUzhm{zvbC?Gw*#_;W{+%+Ste2xYorW~#F7lWR=IyFKMed?1X^?Q z8g}%u%a52?KZv5shO8vGZr4kC!5%$N%hXU7P>mQd`-L)2gipRG10gex9mo8Ek0wVNc3~>9Uw#gf;Md{TjE7%$ zT&v(?@P)&a<6d?#dW^wvLhbcFscNn3X6%RpA&^NWP!Tg(8!#f5yVBdiG}pASN&R7 zFBmkcQb5JR!`BuJ6{mi%{Fus=tGRY##mWh+6(oR}XB?|1<005#pCN1PP|VF)(R){X z5jbbTku{MeUynEG@1^rJ{?Hs-pTCwf3l#bRmLP1yQ>%#qw-A2PlM(n?Z+o~RsQxE4 z$-Cno_w(n^6P4JxC!>@ix5T6?&OiU$VC4J6cFO)Op#?CSOroeG zugX!N9KZjmHYw{Y`|mv2fdZ>L`ah?`3!f4H8`EDmpvA<*U~0?wnH`eOW|z4YpWL}I ze``w?3IsdMf%(iV3VyVxvcEKysqictoEWW_FyA>Cy&-e9e=le~V0LZBM0b}tu;lCpJmY>N`+p^2T0pvyR z3?uf`o4PX4j#(wS>-QWULNg*13eq@Pc_-B2if6CBOgg-z3v$@soo4k`Y3S#BXG+oB z4cePGb{v!xI=1rEKxfy-m0M@KBNrhXFjA29;52fqxp@5SnYv9|x|leGHJFLZPCZhK z^P>^S0))rF$nPolRmQxKS2t2-d3WvEZBQaA7jj%i6C1sCT1Z3U%deLeuBZTsZIL-N z2WHlFJQDG9r-9io_JlHvR_%D2`Mgz$p~e=%*@9tn<8vs!8ZB zBRkPRbw9rSw0LvN_DWnLNu>kE-bU>p2kC7Eiu}jxI$vSA5;wE!p{@%@Qd}}VhCxC8 zo8v$EY?hFAz=Wt936kF+_?YZ!5%F{yfXbt_KILa_=O2BzajEC-y?fF57`-?iqP;;) zCRd|wFZxgM3DcN~+NR$>ih%6IsFHqyrwFu4@bpE1^dA-nXK`q#8bq&v9r?hYjrjKm z0G+Tr=KRzEhNEZ=ICe}3rYx@sar)-F@?)k}7d0!(y}3*4Z^1&s(NOmJzS`WMES*V=bc0 z+z%R?f1PnEw25DBXxOy#&&9Nw0WMXQR3xxZtcG_l+^mSn-nBFYB@lsRenj>J;2@#W zgImSE`EdC7@MVUxhHW}INK>t7CRP;aJ2277ws5rw&{5i|C6bJO7%b=;*WzQPyfMG= zmR|*Zsj8BMTr_p+34~zGbtz?lKQZVr?Aa2^atMP|euq>at)5q6$t1Ne0G+|Qbr+rH z952xyqP0zamS%imN9A=wzLhUv)9@1HXeD`I8o5W3#Z?&Ch)@}h~hF=ru<}nCf6~qF)S}IiY0ETJ0}AEu-dl zy_EN#$W_*6i)fC27It0wIrqj`!N6l^_0wTp+em{<&ZD`zw4uP5$YpbFC36LSN>kHrUhOV~JD+U- zDiFX8@)kgRPd*$rSpwKZS2XwN1_0LeQP*ls@#unK3pb6Q1X=Eec2g1BB5Rjj3e=5o z(}{W@(Vs^QG%|V#k@w1?i%y0uz9%3E0Abzw)HqRu+91MHCs_bxU>v=a5?8_{2g+r` zO_u?e)_L2s(CMuR<RtOr3$N~|r`){c_ zt+}RpyM9hkD;L~rqhD?+a@INAdC!T;MH?qb4W5?BSfO#m8#)V$#vXhH2)?RJ8!)Or zGwEa*04P`%$fYfFebApTy65)yJ7e3XO&ghe>r`zA+rr-lpsNTs0+F3nyNj5P`HRRn zM5&785g~L2!P=OatFS}h4g=n!X9drXw)C+t?%JhG+L0Cry2j)D0Fo@sY|;Dzm~Q@K zLDw5A6PMOSF}M>OP6)sEsI-M|&HcDjRGIckX~Xf>3JN9d_fMo9LbJfe0b(ZJy1b5| z=5cJPNTyv0pBwVBO$(2R^_I`>rsLD>NcN#uu%!x^3B567y8m@ri%iWxV5P=&8 z=wcM{G;Om#GdT`FgGu(Ji|sqINC~hntUA3?*m~7GCEoPdS1vXK*+$V*e=*R&VWp9& z$@p=-6*(5x_C^@oaziWHAA5nz4oybkIP(jk`0nHlZkqnt1xn!1HM4zt4mW%|_gGlU zCMP)V$Xq_uY@X1UFCeC-WAirsmDz^F3^No(i*OMmBlqz@FxxYttRz^0-gm8CaWghH zR-}cZaHUZa4DPlp>{n;O-%!4X5|7cbJQ ziuH5nv4Lsm9{U|WK4Q4QoX-#Vf}%sj#^f1nPf*;7I{Mf9Zi)FC6Mrc10)z6JTdS6l zeWo`xBqau$JM6j{rO(x62PfWj)XBID%dYv)vw!Qa^_t?>b23pLXkX0aZ4aDB5-p57 zVOvpiudy`uU8(rB7s?A^@6(o|>V*b12<92Q8j5>uYmerGg5j}yZ8@X%Wqbi+Qq$VL802O-%XZIKqtX3zNvljBUQGyR@$8&&pyU+ zE}srRNNS(i6`cekxLPJpE@tP7Y}F7yqR%;(t%I~cx>{tMoO5w&krrme@m&us@>}TI z_d{3tw0|#*&2U#%1*7Ba+Y*{MTS3hhWGVAu>5hIv%P~B>X zvKnRB3uioc!5T9zn*l&!!3^xVgmMi}CLTSpI>+hhpE)#kGBe^{yG0o)v(MY4>HdwH zI4J(BOEWY6(Q=aZ9N{er^@?5$M%ajKbH{eD3z+JLJcN6LJnX-$jOnm-~>XP5oExc2kss?TqNfwGQJOJhq&ViH!U4euM8g!71>(DV2ig{sEm zgV^(tHDuII4yJG+fU2b?TAJKMGx3)LU!P~51L*{{Q)8SHSFJjC({}$yCEMb#hmL>tk26_+X`3{tKbb^^4pLSpGf2wQJF5D?QVWG=IqHHFNT9!(XrfK9Jumxiq`n2PC{w| zjr!)ZWetEdS^vE7LZblxc*gPZA?!LtpB9J74Q;SR+tF4NIu;m=Uz4AZ{p{HzN>jgh z>K&}Lnqf&p4&F%LMa6r9UJt_o3)kGqE-qt0V83HU*NQ`B09e>a5(_mWBTvpEeYk{% zz|2wFxv6VbujUc+-~ZD8K{1!p=I=yhpI`Fls zu3iczIvPF-bt{p=e8p|&F&%$ynjjc*Uh|#>DT_MQ_JO9tAQ4uBe)yabSi68}tl-VF z=0jG?pp}>L6h%kBHBFLA+-iZbPqA?iiRUHU4NtVFjX8N^gfg;PH@En`>0LtN_s)thL1&cFdSLW@ayvLpnq0BXe;B00KnljIi|w$RI=F zxsOg9NOHaKc%_iD`M*SgPn?NY>S7uqUf<`>;|=j1vGWINEG`?X8YFQ+8~r0T$Ds`q z_u@;oUGGg9k=pWw8egb&X*0@#p^L8g66tvYBCOJedJvmD`cz+q5|gd9j+gXIIFHoQ zn6%6OReE`20w>=HR#wc~S6V$DZ$2qP=WqW2?ubl`YniSWl51C?HNRP=hfZkZo#T_A z#o+gyo34Jm9bF{&;QN|?r68bZqV?r$0J+5%?9+|o9}(2ORvpg1k>L!$#Yd620M>MG zMa<@i=%o@JV&A$C;fhiF*rO|yZ3WmW@Iuk%u>(gR093gAh}e3_R|hPRfOSh_#Ln0& zSLgu;^f?3GaQB0QV7Ejljhb*N1}HIdw$y#j*;wGmdI!T7)?e!CqjxMT+M=21o!on? z+P6elqQMz=34$e5g5dvbHSzDWE#9cbR5(l%!u*z;jf1p_2E$?OF0`83@g#1Ih>f6;o`xCj`mwecU!<0#4i&CV1QV5|_q za!jotM-=-6FiaU`Rkcp&K4Kw+OyAT3RhCYGW(7N}AMo+{QM>EAa0YzzLey3IurR)8eg_H&y9jayN3G8ls9wPzMy}Hzyap6{gF|qT5iq2=2hdhW{L-GQJ{Cxr^T&FWb?_Ff61wUUnq2Y}UwXLYg(6K{W*$sKhu5Jm8 zZydQk?w_J8r$jzDc51kS36?)}=!_LA&MzE4pIh(x+m+al%@G=X@4Zdj!Dbd*Tib(r z2O;iwWhG_*2=6WHDH_@HkdJ}gq1n4RfhK<>9QbeJ5ddBhO$d?f+8x(0ctpqDKK2l1 zDHQ32_-rwzM5n+B>F;MM42$XZ(3 zY6=b8{%&1;?YBsk3w|~=Vm*1kv3ylod|Om`^7-l?e5Z8EnCTdGGYb>%GqU4I>IIpp zQu_j)L6VtJTfYSo&^_FlrqO2b#<=}oINx)OH{1iNXJ-N9JJvspmfQi+kT89G%gS5E7fnSsC0wU_w&(&!j?j=-v>>Q358?nc~bNP%+Z z!HL?N1b32Pz>eM6LGDMLG;N{UvLv0nfg_gp>J8dlFy@Fnd~h8Gzo_%n-k7FG6~t}AhL5(BD! zxCe^SAf!*t>&0qvt&mfSvjQW*=!$hfo3eO~G6Vq!ftF847Ftu!T*Nb$(-%nCqX*dyg!?5O5zMKBjjGfHUUH$p})t_K+O zV4GYqr_5{k?Ml-C>qFF^EtCM-3>)s1m;qvpDiO%UBbC%#kcn?9qO8ingv__{IalV5;7;484C zV@8!rhA?{nR5gFg znzrR89)mJ`ul&KzDA}SyO^BVS5ITwNhxY$8TqJuvM2Rh}8_)^r&@9Xhp1u3IwWjyl zKYG+Q0E&p#P*jM(9kSEV{^96ss;GNjjho6UxU%YWwcM>CdIWe}dJodcd_%t-8r1Cb ztDXva%1WSvvb3z=>`^JgWxH-1L0jd_-yk`_GTM zUOHHiBp7r2Gz8o#e>)BwFo4`S#-2Zv)~bQeiNffuV)PE=j1*16x(3YePTL^yOoCQ; z+XZNAG=ujb-jZ7|sB8>*8IY^5{F2ZQ%&LNCmx>1{pZfD`jr^cw8Jic@ zliVIqNQIw_!)Kb-K_r$?ct4EOVMGj@@r;S*I!s6qRas#n4rKTZ{$v1l>6AhE8Cs*k z0y)EraoUO%XK&uqt+**}aDq6ziSxogiJ@T$x1 z)vG!7_eh9wACIho*pg4#c8@L56e185zHS^ZQwDxg=u>yKw`p7({KA%_B$Woa#V<<8N(cfhJQ55!hj< zCSaRzjcGx-{8VwY(;T;%Ha6kMn(Ait3JrTVy3IHPwnTu?7;_urpzpfYny&0m+H7tNb4cve{yoDNU64Y zs>{Jfwf=!jmp;{c!{6|L@Z&Hm-Ee%oo0v!+7FTNFf`EJT-AJXxZ844rYxrfmPg(D^ zs@jsND=$I@Cj}rml-nqVbaV6{lvM{FIg-x!2m|nWGvSWR*u$|B{y8inu&MwKJT_5wv5R|piD50- zht;$z+VOtH1^%XG)LMTM?82}CM0Y%@_#gI0L)IA^b@N+m$~kk~K9QQR>O$kgH6byG z3FUuQUKhqcT8D{?lnIosRQ#Y2#PAK>AneT!<_EliH}KC3JD3e>+9xE`F#10s4S} z1aTU75yx}33paMWo=By5@O8Y}>Y1bxjWg17sb#GkHQoEx#tKT-Ii>@tc<7KqxkR)M zvPN#cg>#qKT$f%HfRZm7;aD=SYcHm^BFV%oLs*7Adb9}$2uNDIK4CHs0vJVwI-Eeo z%MdqSsY9i@W!CiVQ9;1au=Z)x=lHaXX1EkTKADi&_3I7lRY{>w0tTtY zR=evF&JNWNXNg6reXq7H#07+*oP_Uj;}&Lyf|`oQjqPzavCMtN9)Q`fx9>OoJFg-* zP>YNF#23A+YWN^Xg#P2dwSrwVBr*vQ_(kCDRr)!!62F}@Lf-~O?AZ$=s2y%>e4x=CDyVAB+rVoi~-=&MQS<3`qExFa8 z`(g-0NsJ@MV_+$p^3G*P?A@EKXeS69ag~-BP>&*L;nNS|{*6m4b1bM+M!ZTcAoHlf zg9nQ+(UTb)v2g$L+bc(ji7RiHg*D5Qyc$RMrHY#k7em@YSqug4g+ZM_WRxhKxT^*R zjp-CXiI{%g$ETX@)~%Q%-C2Fr#oiL{xDlwtLq2z_ok0Z~{Mj6WA4@doY02Yazs*OX zJ-Au@VzFK%TZVSil}qyl2}4Y4h7J33MIR(w4bRg$vf>6(Xo{^8}dMCVtJZnDogsk&R7xil{78p2c%0F z2*h?Oaq+E?)dPq~`n3I1ZB`Mli}q3x7>St4^U+H;zRM%$ukI&MQ|Gg=nTI3CWL)BP zb4@066r=^0QM-av(mn%X7^#6N0d0XS8YNuxUC%e3 zvqizQB7(p*o%&=)YQeuQGI@!KsV>orE)s$T7(QkiVcrjYBIfc`Q5Z6AxRLuUVY`oy z#0b1sAL8|MhtCcGDC%lv(V_w77ZGX>j_>B{ZvB3(YT@}NtJbUudR!6Op>bykl$dKE zjKf%CT0=wN-?xT}rC(1GDmvPd+x=r$CogELfu8*yeP+$+rqdOoeZsT3u%S z=7(Q)_%RXnA)&~GL1b-c_{pv2N@D>kgcWI%MXX0TOPa{9{2}E$g4Xd7U)7H%Xt4f5 z26-gH6})n-+O&7}QzNP{kqq~-a&j8Qwg>p`s$!3V^^ZK6>CV?J=R*nsCQy*^xx3yPCuXJ+ht1Zucqd-(E{E@Qo8G3;~!2@F>U{f8joY73rrUDs~Am|JtwKBTUC< z>+=p@v6~tDt^q@bwxxm=e$R=0XSnZFd^0$kSZ%1Y=@Rx|dd)^SRFqe#98Dm0{>F^- z1AhGK7G`wRA^W>sXz@B8f&k`NkgTS*Y_-X8Lf6fGO~ba*&;aye)DS3AX9#Z<wq3Q<%xxk7ZilKRtB57lZWQ~`=VVU^%F zgxla02IqFMM?sYM(huZP;-9zk$CB_l;)5lUf)$y^2F3idZp=~hHhBc3LC@jA{PEVa zVY7_A0BfXECeI|5%O`tSEnI^m8C@H(BC}foEtR$qeLw45xdUd|v-+iC^@$Ye;`+Ohg z9OLtTzh2LAUDxw^$~yLVZYf}pS!1vjh86fy0KAI-*lQ9G5~G=rWMg&wx2Q}AG-3iS zbR(L18ylPGWe6(BU>j_jdH?{Lfa_~gADdnH;j>QD!dG8_WnNavHa98h;ODi5Lea5~ z71A7BT>I*k((i_o$~%OrR~tVd@ zfKAuQ6os@A1N_8^{uE&oNh__O;4X?RdyGlqydTCd1SIBm5^SX|q_*oLQsqK}yE!jrcLA_X!l9VVJkAJAoMfw+gZ0lx+u zVShjcu)$$2>U64DrgQ1ukM zlP8&oV2g3&>Kf(qnPc5=6tGwdVf*NAf+dl9Z>##qv~kqan4h$UZ0Nlg?du)#f56KZzI zo({JjTsuQoA?y0>-P+v;9%Dt^{slalW^_OJ-BA|e{-Eu`r1;Q#*45c(fWyXYJvpnt=MB3bp@w{bUF%$$s)b&Y`D8+ZvB9wi^M{tg_f2!?iuyIpkN zOMdjfI9(WOV>+^O1hZI#?FJ zOAv46)(G`snFJpyC-TXG%nO@+UHYDd;55uY0&p{Po!#wb+E zJNE3!sPWNfO;%Z1cQKs^y+|XEP4g$D={-f7f<9hBW~4UihFRB_y)G_#9gv^FI~8qu zYI8Bd18m#g8yo?b5AzFb#wE2Em>i63<3Ss4>B-N@^j%p-1A_SsqV)HdP6RfKOCdse zhV|+LIFsa>1>cD!oNaQjP$XT%IohJv&>j+=_4ni?G|a?&2*npDkN5TUu(`g0{DY8< zrRV37XK}Uw_3*}J<7z=KRc8}%IB+%Yi zd*=f+ao2dW&^1PQED(hRPl?j^0HwLjw(u${=WP&n;Mk%y+m17VItUQmZn#aMPeb^7 zQqIo3?kGK=fAGd2j?6U3NvR%l;?%-m34}uM2vEoIHf_)b;YXS=Y)CsXq;2{W(eAzC zojUJ&^Sm6}u9rgI`^M!fige)lhzXp;MoT}{=7pv7;4=V6AAm9kpu%lqLIidYr5y=K znDYX>f-Zn#rtInxnRNsV|GLxQgoVFdKECzpxEh(WCr$l>ow+3>WUa7E`B)`t#iTNgM5kxj3a&(xfJRb#^@YkFao?bSs2p>f&%M0W$lr0x< z_(^9En-4V3s8+q8?m*B!C&u<_gb8b}fclgB2WkY}Bc2y&J;@z_yXub%&e9x&C&Um9 zQW8e{zV?pr)G6`gX^yO=%<7+TBWgfoe|#j=ZeMoFl=V#jL#zf>O(du|0wM!6n&GvG zT3~3DQE%eU!K%Bo-{DBG9hUo`6L8}ppCtyd5Z)kNLr+cSFM%Dv1K_(m>_hG$2I{+3 zpA*h@9@a%8@NjDQeqAJvKkk9^t<$H;s7X??q~dZp9PGu!#UaM*GMRb+^>k6ue29ty z0%NMZ;>BvJ9VY(h7JnpfyV*G!rA?G(5Uawz1ueUrsv_dy)@hh~aEhfiK?DGC2(gO= z){d7g3-gq*{X2~=9wCJ;;L8qd*&&?_PD+a1-`m0!w86#9&@vkc#sDXK-NPpxqV5u@4yQ(}?RH!-@|mZeaT#8L_Msw7Y#* zCRDG;)j+{3UCtk8LRA(zH`*YGv zcz$R(oNfg`N<*4$_wD0v{hqp!p-!1%1r}!wqPUB?fCdQQa~L$_7Q{#bQLfXBic8*y7SRBDo2k7 zJB6VBMwWOPDzJVrOz*|Da-rjwF>xE^6SPz{Zd|_(#sys` ziIg>>Yq#y z&_-D`xxDw5x4z*awZ~nMrR4JehP>CkvWMYKcRg4)s6N4>fEz}ecOKIm!03q02%rEA zIN3KWh2X38AMm`C5)s)1gc}G7A-O|Uk)Wj|70F;4_Dltoo`cJ3k3$hQ%pGj|f?J4x zP4=pR^4Shu0~rb9S{gW9*!3&%dcVw$SbY4|a9hMD+dHr*(8?dN6`>LM)nHS&a^m50 zMz0j}4K4*N@BG>kx$`LLQbC!^n#ioJ%eNSO{cH_dK`UEZSp3$EJcQf&{m1+zd*k1v zrmn`^UH_aSpeF=IU}#hP0Xx{ACd*&b8#cLcL3%{)*skWSZK234Xd2A21KYQ(=R1jy zgk0)y?B@I_HFK8|U&im;K|H~thkIC^B-zQ{I21u--x1+NA#U>cD7^V9Gn$*|K5qD; zU}%8Y4-)bK2pEwUfA`orunI|+K*t~+j9Rbdb&itzv{!V;pF8c*Pg<-mN<5JDcXBmH z-;7A`zBu(JExpULNjF)M$z5c%rO0rR%)6Mqv1M}VYXj$xNnWoIMVm`bs#n_&2gpO( zM&tKMOaFNt%S9Qn+D^im<=s}LeEa@Pzq*0xKZ%~yfbc~V5eaeo-DEidkG_gqxHf4> z{SoW2;%Byh8+q1ZZDy_kyYrFqExEr!2ef#GRU}Wtn@!q~uYtk=jXLT6k?Q<$D?5e} z5kwd=Ma#G`#=sSp)zE78aI# z1sQfaTn*r)8$HrCEONjmrL+qFOj*Y z=>@yS4T2W$`uUalt_z}4?v**rZYeZXHh>!*(SPnn^{-zl54~ob)J7ILa@^6kAo5HX zo`2Kev_*Z|$Z(B^Egyun>u&^C4>=aq%@(NtDal&Br6j@CHEyiWpnJ=fmvN|}fMEY5 zOUR%Mun>sJ(0zij=(OzY;~$R1?(4RY@$M{fedb+|?gEY5H%Q@nK&d^1;z17G(U^ zFj_Pp{Qv#{<2s+)VE;_plE93F>zI~8t(on!xK5#={2S_;cYZ_ktk+U=@yE1W5;ebfryOJP@&wh&~_k?<~iNdt< znQ-+wsd88^K=D#P^uj|43DQb}dyMbbep}gKQI2tFFiM>I~f624cKRqWt|^ zoyu+-9zEVRws7x(?|}EW03YA(HT^>&a}r_Ulk;^&4r#+T3%&t)L4S%7>&&}uNm;2d zut*&54vnNp{`i1roe;-C43}Fq7L(NLZ<1zLER!bNK4rFyG>r#M79C@;bf3blf%!g5Z-wZFt zS3Vf5M#q86OUAj-FF{M<-#mD#_3-r-cjcK;d_thWCeRH50ynz|Mwp-9%Z7Q)&3_KT z&aweJ3H}>%>-~S%ZDw8;U8{S=Kr&z-uM!WfWXeO`{so)MmjQCY2Mr+lmQ&%z2YZ!x z*g3D>6wi7kKQ=jUYP42+%~5a@`ZaJuPSBC~d$|0Rg#Yss$WOsxqi0ijV+57nfBy6u z9=E>o_gDGn|1BnJvvUEi`|ls@c&)V|73TlnpE-&|Q+%;gS^wnj&r^W-{95$q;L)Q! z*{d}C4khmY`4hWuMU)PCG#uGog;QtG=D*6pr>rlrV|YUm%Awr1X-x;1eJt_tkdw$n zP{*mqryvo_HyEAu0NQ}T?Ca3qpa=GKc@5YX^rHYP`-t4l`c-JG>3rLadV7Dd=U6z( zA&dU&&-E3?H8<_N2ucHzAHW&vXHd@QuTUDDPLAG;!5N?zMEiYGTe}p1j3zegi0LlB z;0b95=8Fm2kn3E+n5TqF8N)u{t$BCtf^8Xc5}DpX=I-&6{DtZiST3qISmnJ)2n-GV z@!2M8C&n=V*`7Rg3e(L4?nVm=8}Sf83R#M|upYg7X}gFidXrUTtoYhBbYsZ$b6zEM zZ(uA0Uw)Lx4oDnjeh|}lU)HR)y&G2yp0eCMvN^x5 z%okQCahQ<%P0T*g=|QwZ)|r5uT6XP@QR2f8f@_!amIGhxs8%h^EGZBZnv6d%p?T$M z`%}X~RE_ff2xHXt?K;Y-k`$}&WVAmv>ehW~6qXU5tnjp;j_%BmWt`65n>|xpTYu7eEdXmteLR&oRQ#U{2IJU6l{^qPq2VO^`B{K*=beLdI zEu838w6R{o6u6C{@6t3WCvBb;(Ewj9JvT6^kG*o(IA!j}6729^>6a_8(Y^!2%irxA{!gcxDxDmf*&& zE)CvIrPc)~c(qu)6XR*^r5-u(>(?&kbx=X0!>$^AW3o^; z=rrTLa!k9K_Hm_&dZgHzSWJgl7sO}Rf7&0Unu2ELWRu(R8fr=k+$GKQyzYG6p>}w> zU_FW2(W5&A1vdim7GT-2DR7mgcV|MWC}%i+ueCJ(_`M8;+UylPcRh;PV%^;_I@32} zB*4-fQi|aw&}Q2i6r+Ch%~Ui#GI{v?{h*U|U3Ptj#)Mm}n=G9oI0neP$?^ms&%Yu2 zwN5}S4Jva(;j{v1}FFBJSB*Y#&-g=VQi{NvmJJz z5rh*#eM@xgm=y4BXly(P0SOwm82iIfW(a-IjGE$*l6AD0N`XlWQPx6gV{vEK<5#a9 zK3Wm%<4cUf(P*`U}REcnEYJ?fLi;GLve(g?}(X|)YYyoH#>$^NpbVEQ} z!C?^FFHmH~L#bcWv$JmzkRH(rRIwX0mf{iLL{+qz!O_Bw7|o79&?f*nrD$(~Dr1B? z7qu?n6Qs!z9o;B1G;0)+W3MOkC3JOkfKMC*Tg4KbS}(V4+cpZaUKSCAKS&909TZyG zg&YK}M3n-ZWW0L(68NnUDvmR+Mn8-R8`~GJqRaPjhTE9zQ{My?UuE9}aS_5q=h}es zi|yH0A>zV(Yu#t7DKy5ES688cZVg$1&ifm%)GB=YsHj#OO>=W|49}A(T~t&dTnh0( z=#WLx_?jY^pu`gV(U9Z)%b|#@I3;pX9!3H?f)uN_et@WlB>_bJ35Xig>=56e`ecFn z0LFY-B}o0)Rv2;r{^QhW^kX~>w1kqucgG$4^~+XTMrO?^KJ`eH3{ZvrP5EbHwVlUyOag+7iv!_GAN9mUdVNVoT`~Sk zuxg6`Y$zik8*gRWlS@S#-^?Ik3o|nYj>#Gb$RDdlUHV*Jt@(7>f(kzFIDh1}VQU1` zHL%?yi`~EhTVR$EG>|9Ss7@Zd|* z5f^e;k-ni^hXW>I!-#kXFh$u~Jca*8aLTJaj` zk-Wb`Tp1ZpMZT;)j2R3U{32^QhM)`=$gNRSC$4r#wM)`U0uw;~p2#8Rm z0s3)uJpihHvvtN`wY98caZ6`idiy%iB;oO($U$x@AWuk zz3k)$JPZuiPZnf(hhp$18zl;-&)j(o)243yDkCFvAFgYd*A+#$B};rr#DPh{M$3<% zZpJ|NdIg=HbtCOn#Vh2>0_mSby?~g9^^ln3yGDdt&|MSm8O1|k?~QHbcaV61valYC zU$QaV4b3VpBRFon5b$8=@WPQG!|{+1BH)xrG@9VgDOy%6KQT=}C_!AB%<-Twu0!B4 zGTDc%pp+E&ohvXJqm=4;3{!ah4m7~ws0IZ=!uZek1%eGHr%IW;oR zMDpTnlUWMFqM@sUoC7x!fEPK8M01SK8*Bdnod6k(hcz+bc^W_CF{+UBfV~&QBNpZw zqIwF@(uS~!A8{OCMWzIxz%@L2{~_jBA+*FBU~e#X0=rE%dO-SnBbgudHWEHu6Z?ry zi6z*oE)O=Y30&BoWFY1j)H(VSqKb2{7{!O=QQw5TM+RfQA;hQV<(+iY#N$C0B!`P+ zIvh!qO7o z&OoLsUv#1?LheL=Qv{tNSzrde>d`!GEk#78(u`&fR7U67Lo6!rY|%a<<uT|y>+ zG3wJih&d8ayM(?6tVE2a5Ld_|MG+BD&jx?8kyMyfFi9_p(@CbfLGHrEX#9%St|x<* z1Sot~^pg#ej@DgBiMV56&ebB{;j?Pl+6qm$83yaZ5!e)69H|h#eti~a%8B@ZR2oB7 zfq76UAjq~Lj9)`ufx;RRF8~xIrXpm~KiYUWMH-4(2iU=4hV1B3#z;q~(pmO_wYDX(5( zS&KnL70h0tKKlZZ12GkHlf}|C_zpY{&uO-Ofv!O;#RES8+uDy#k`3ATR`ui;*3clH zKIepQV%t16Ws^B3^j&1aL(sWR5&NV`@nb!t*1^~MBc{Nsppq$2%fOH7xM&XM9seHiR}lg!z-Fqj z8K6uT&~K6t06vf*R23_M_&CFFW9>Iy3DHe6(2j%dg1di&nLZP5G_Shj_D6%WpNgws z-0bqIp5OTlq+5u7)w!7})i`FLWDCTOT8rsRvbVBt4@$;iCI=K8a$D1}B10jRIGC#0 z$v|s|oErJy!4oMF+u~_mgu^Q}Kaz`z^dIy?4xCt6$O;cHE#E&U@ePqliIF>IseY$$ zEz!cZ2N+vdR~MizU{etCkh+k}SfX|m^s~aIEF$G!X@Joox{>PMp_)hj^JkImTBQ2$p8u879hUU2#(qulhTlVp)38U4qI5gH}fSYnlOHq5AqRtsiILPRAEWX0du%1x66FV&!~fFz1+p6j%I6$)j&9r7=F zuS3?>8Lj5i%Nb$U+lJweJ_!cI0pzJYD85NxMT~*UWDW#~20$>SBnDK9eb62AadTfK zlabi}Xpcdfzb}hyuZI2r7rx3NUq}(FNRU9u+b0uyWRRI86CRe}{ZOfqd0!O0=o3*u zV1)h<@-I^0UylxWko97AKmrv$9drbk9R`a_8&IxTFZ&WUw(tR?MO`bqJl98tEY}Ma zUV+uusiSwrpcz5ijq^v6+8NuT>zKEmz`IZqUAGU_1gb2{Mg=l74loI11%5sm3HfV< zrlDbm*}g@3U3-I{1T*K#}A|$l0k|6s>=(YI|&A)5hH`45gGPww^8~Dhp=TS8uK8S6+jZBse!w6I>?O4#F zFjq={52QO8qz0?}*1CY~Vn)V%O(9c#DlVMJGNAoH!9U-eu>20>8a6IdyjfWqBTEA@ zi+b6sZ`)^fSLrz7J3)#cqR7RHR;p~6AeBr%y-bL4D-&mQUG$QhYKNzDkpzjFlx3gs zM?_zGRCgGH_nLjx-SW#3;3GclVp_d7YE{6OL@h|LH$!NW$W}540~jEa#cb1tsm%2_ zY1r>b{D1M*P*k5K?;W7~@8U7|AILev6hKwQ$G zVpy-^xSK810Y&vHT+vW?T#e0=wWaa~D(Od%2T)mBe6NnUrol1$jwmcYTEqIlPT0Z^Xpb=$hdU>=2C3OR}t&a9rla2oPh0zG`Zfd+% z*Q=mZjrK!NfCi=-e@ADD_dJ?(2ujc|TxDI>IfP{}TLIlc=Hm$K}4CwU1C1l679lJ;b8|Q-oxdDcRDBK~TJ1Cl{9)hbz?n2cfBZ z_Tx7fEnXwJ3?tq&|;9`0>9F82{vT)=ua$+S%KFoHv?@0 zpf3OgKw4TfMd5NG@{fjaeBO(gsr-N&qZE)qGIB;XO5${g0x6*==u<|VPynrFj471+ zXbd5|A)8x?=m;uFkRFzP?L-oe&0pr|V97qNwr7mA)MOeG1w4^-2P%vCV|6AQRPAA5 zbl485LJUYiyQQF{fL0XK&S|NX!OA>jMJr-B`t79Sv_PM%w2>wftFoW2d*JHqj4^tk z8WKUt#9U(hpUDPtGSU11k|vI%_?>Y0sD_>>Zhwm=V=DsFDl#lju1wUfk#c$kWo_n? z$+$6LsD1DS#LmsbXb1DU35aEHgR>$v4MeFxbx>;Fii>+X_zubSIDRmYGbB~AHLho$ z@BDi7`G{M=sOf(H{!M-^WM46uTOa)O>nGL}wS1l;JzeN!FMs4mBrZH1fcb$E438E%)~KoZ#iIPhHkyr)gqPA zw_+8MQY5d&$3;Cxr@KgvB`^y-iG|*XanOGPJq#t?qn;k^l;nja9|g=1<3U1mQ4Q^X z8IW<@pneR4R$#UT;RclxnFYj}sh2F5F!mNCP(u_(@Zn>-f*$rZur$1)n-v!TkpLk@ zj)3q3GU6aud1IgDV$T|dh9S>pjROEg-tUYyEW-+C9S9G9*7L?3)Mm!(gkh^Ao3V3? z4vBuS`CEn6F*St$M=uhg_C0@$kuKmLCIfKJ*qar32tsH);@9BSqHMuP4hrhx4{RG~ z@UKb3g@c7^Ef15!K&0NOrART;llPf4z7$+sJ7f9%Wk zpzfH_wvLVt0|O~c)Wf3m@zd4LKCp$j6%mniSEkFov8AN~kgfQo2vkbo2J-DU1eB8r zEtn4gSbtwrGu-#R6HY0>5Sv;v&>KTa?d0l8tU`|*33@&QkTyL*JA>8rv;KPK=1zgzUoZ^^fL``vS@v z`yQqZ$(*rzfMg5sxveC`yiz2>&#{SiY9z#LI1 zI81Xwj>*Xi;<<{U-bYWLnnArus#Qg;&j)}og6%*w z0YHg#4BQdC+G$5t>=1%WnG(`j)7gr?38v?hZ?5wG^_j$k@U**@HZz zq_N9;-2;6Il`k1jP(tKKW+Dg)R2B>a>Q`Ma8psZG>?M|bd$V3A7e_@!N%TBY_!pAy zrxm=ffJ(Vm%Qi4ZcI0SD)Tk2nM1k+WN~K-ayFZEh$T^yW=))yk#<+p4O^82GJv; z4~&Uz92-nUfW0Rjze$Q6(s&@$)t-k*p;-CTp>}Pwr)EeoWFC?%$|0q2O$}K46g*jo zUx>6M8}c(z`QB&NUx{a5_tG!d&(o*YKvqWsT^P8+*VeEO;sZ$X7;#>G@Sq0eFtOM8$HFNa_rM3fI@Ys@1%N&f`B5K1Kfd`*cUA%D0VpY0 zD8zA^AwaATqWbQXulMO@CPq6zfFM6&I1F}P7z(#QQDN|~S5e#;(F%dR^@IdYn*$6^ zQ6`Y374!HDBPu(=x#qXdR<$oTs?ZF!rjYEEVZ_M21+gAGdL+vja=z6%iZ$d{s%HCzhEz2vzF)_7Rq56pT( zwllA5gz+kBLGU$Dv52^h_w+?^-IN@k*br#J!{aR*cHnsxS;&Bc^4ByE5@hHqVY+-C zPYGXyHV1%DcfWMT@oAKS-X${yS!Br-E-H>7jOh+C-8F{M4fdSRTyds8h1vmvF(?q_ zms0tISf-QHI@xHKq+W(CzhG-g^_gj3V3*<3JVifkj64vtW84yODRFuMjuw6Ux06&W zZu$dWb?F^R!}T%Ek6(SJ2Q-WUky-uqc+CL#P!&8Xh5(f8b^=&U1|>1069JF|P!q;6 z2=fbfPoY;MA{sS^TLexpUq`Lp3^MFVcZ&ylE^w-HiI#AhODyOz6*-^R@y9vs3Ejt( zvGtaY!$9woa&m}6VYu(23861Q4MF-`oa{j3hdEIcfCRXP-*TPYtRdn!LKnd%z+{od zM@YG3dspk2&ogVGiVOQ|1BGI@|2x%>dF|%d-e7Xd+R_qx*(2lQHBlp?-1>Dp_fYM7 zu0Um6ElJxW;a7hyD{#U&1w8>jbJ;}+3{xb>Md@lO1(cr`ynBnS7;4M{_HmFc(T&>T z85_xVSSbpQ4K(6?_W{)I+==iM!K*Ghv)cb4sutv6^+;R{=zXDD2aYdz=D9G&sUdwX z0_R;cUbenK>;%YyvaM?ajz*F_pu`vsgb&2X2( zJkc`jxjEZdUt+T{;V^*70Ieb0KylIS;0{OK#;f+hg79Op9 zOv56n zgv~KX*c+rI%h?vY8TXhfXc;gIMDK{|s2Tx>%r3kJ2#eYYO?!dV8o7nj+ZoOQMZt|{ zQqd-bFofIiWt;HlTD$@?$fc0jj}6UVF+=$mMgqb~fKos(5cd)$1EN{*qQoSboo?2F z=6!AmuR73FAY|AsnOI!B6w>Pde;U=&`y6|hVw2NCG9Ep8+doNQGg2~RwxqyG-8v%$ z#XFT;Wj#tq!HUDOoMUElGSWleF0xW@q@hAZfUF6=6!S7^I8NxO5O48M;Z<;#6p*ru zzFS9iH%|nWS$6Y`Eb#2cA>$5E!>l>xxk{4W1*iX(owu> z23jDKFr$+SqOen>^& zEiqh&f9wWRc%V)6?bygsKj1c$mblU?2N&(o$1yse596<7eGf@09%Q)+BB(Gpg%pNh$KxI51gF6w{-!IzQ z?qUhn{=0>O>{zX9P3RnLJgCv4W?h$%lF$qVxI z82VjQ5g>;}9ax}aqLibnfQwNM5Cf2MsFE9+-W#EJL9lapdnvc$B6d|^5#m+oJ0Z&X z%(mgW5-1zuKZ2dzY@xX@My&r>cu|-&-;x=%cvas&`j!kSf104roCWn*sER(n@n<7y zCg9W-xHLpB4P4$j>^vZ$OE0=A;dNF=e4#J}2*#=(N+@o6+3{8ASipSXsfT0E!n%N9 zkZ3Vc%K<1O4jpi-B6)hQVCJ*!w=LrNHFFT!0qd&=b z@VrAeYGPuN-~*WydLe?PKypIf#wwzHLYO_99yE095=y66uU@q!oD+Uj3y)ed$WOd^ z=YJk}qt3gj7ZzkaFdh{I$5fD}07(2O7^rVEWyyhhbHr**+_`6KcTTBA8apysNhvLk4RfbX2r$Wh^5QoT-D#Hd3do{^SlGx zN{R~vP2ywdS~AHI@4M^{M2sl#3F8QUgp|NE`6)XTSCLUYR7~3>E-J!P%CeCMrR7!5 z^6&W7^G~+HX$fE^ez`I%Kq5x5yG|_v8jiPrcTF=LenQw-p(((Uf(4t19;7pV`RWy{ zdF|g=CD7njnnRR|C`jfVp|Q5KwKd24L{&#D*Y4~-S@)s%6rup&=C@Sz{{{9r|(UP0X{$z3_HDpPP`t%K(ub8la@(YMp#%JqC9I z(?%@}s>RiHcB_#u+)$u~MEIHe+*YEBd3tC%Q9_WpQoIRxURI3Fl6Jpm+5=rhAI9lP zi4N`PDx6_JMlf>%@B#nWz^XpHP6|2C({KhtB?m&z9w?my#u@VEU_{4`(5A(WyE%?12Sx==zz%$Cf`k%I{Y0u&;d(7?@SRv{J$&~1Jz9DI&TfMeN% zj>Qz9AF3ETRPXjc5oR(HGA`?2@seixu(O=AkA%254q_e3H?qyDdwJ%?;K0CkIk`Ca zzUwv{YjY&{ygBRG^cXM}uqxu5giwkH08g-LJV1Sr{IC;JzKVk4e=hRem*1(*SCFw$ zj1%bvP8IPUBv=w);Z32c@Z12-NGV6zPhK3-JhUaao1~Y4uqTGX6$lLK2J50v1O>;+ zBC*8Si7Vz#20#yq)D%3UfMMY+w50$cR8aRq1pp-atLK=U{Nnf_bhc3Zkv_?>PAY)> z1$tpW@E9DW>i6%t|8WV0gnk|@i7M>CqIlHOn;F+#C;K+>GQUIFjk=$d_>$~6+T%dU zp_TJYRr~EGD+@vjG?g;b#%`oZBs8vDrQ?>UDWUQrW9NVgZXt6LgCZb2;F0ledK9GT ztA=*Ps)fwz*}K8|79NeI$tuD5OiQSL_7PwkcFZun0X6p=Z> z>;trcQwZ6wfCvk)0vskWp+y;i=R|xBgbI5o1vef?+1hJY*w+5Y`3na6kU*1BL?$ea zSQz@u|G-H`hydvdYVtnNQ!6W*Q^YIR)qoe~ek}S1yY}d)G{W9^lW^=|cCW8&Fk7|QKR}caXIn5tlMEG*zID)nB zxHOvSj*62Ixgt~5qU%ux!V8%ikrfe!c{^w^An7RQ2Zx5fp?OJR(P%ikE$mWF2i5@q zk0NI0&{>dA2|zWTye0%TUiAyu4I(@KmMVtk-Q#D6?xNl$=CCldts>ka;tZ4riH{!# zVxANh;UUd~rkM61TQcxp;;kbwVcOabaZ4CB3#c^74mRNaPzXTLMTESBg#%I*^SqG& zsD!u$fQkP(OL(&acOZ(_llcjpwA4C%>bOd00=z*1IIAG15$QJy1Hia=&QK6Q_K}Bz z#26Ne_y;I|mSz$F3cxLpY?TCYifyOIiAfz&PDg=F^o+iWHoojTHV8EH6o#or@{)0H z@K3Hic9JkpBhekbu_Ja>5?fOsqNGlS9|*(?IKb^ij_E%cq&zl+Ydkp6V3F{Ds^Y-X zQJ0n9dC%VYu1CRwV2An{qyzDRgEIhVbO^!_PKf}E$cX{GO0v^|9s?zQ4$&MOy(Qa5 za2Pn-7)DTpdgC(4b1v=a<_f%JR5iFipqD|N!y_HycDUGFMTt#(2#6vFA#>O5wdlQY zLom(S_Pf@?%nVM{I5^2LekNOn;9`2TSpd%o+bgf*Xpv!L^6;SRfb<2RGKYQ?v8>mU zouDR+7=ll6h!rw73RI!j<4PJKi)0Zq<^@5DPJ~n>wn*R1{P8iHPZ-Vy_6)6h5Dua{aFs073{65hzX?LdwD3q`^TCVibPn z3y^JA=c2HV`T3{@<=UpeZ0Hhwz>Hd57^V_$f?Gj0$_{Yi02q)DBJbX9skyZX&I8T( zKFgPKcr?ft#DoL6mSErj8x*yMFv44d1CI4P-|#&8@jMJ-b{A-%yM>YLQ;kSoOvFnQ zoxw5~HS9;LE7U-2MvM1U!#HaTIv|W8kcD7Gp#~*)cxearxizY6A-2yr#jueLO>1+M zK<^IR1)<6hS-x)qj;^>$WC;owFVgBF%ixSd??8osj;aP_dU0z%3^b6Iw6V$o%Sq5A z-j(q6mQP?QCQl!;JXo*-6?G_D5c11`0S*U!Nb|HFwuA^PSQH9XR0yWRU^D}X%%xcL z{+z^sN+Z)alp1z1pIYw>U4fewyQJL}1+ZI-xX(MGs(fo=BXXYzhtbj?WZHP@;Vhvg z082thX(Dw*PCbP4hRg`rwkR~7=ymjp98%Mk5U|k^?t^#`#}#{nRIvsL9?IpgI0Nng zrUJ02yk+VJG}_3!%($Db>x)b?@I3H8yy9j6))3gLKP= zfEl$m7V#~-hC6LyBI>J8d%JFV{Dz--E@GmozI@-w_Ml{j4K&0m2wx6Tgbd8U9fLSy z0kX#(n3gWcio&rbLUAN%@C0~X5C#(dT+l4~^D^MlyO+I(&k#{3!5E4E4n0>`^B`Ff zK$=IC*%0w8u%0r(5HKW%7{tLvP5_Ep0$Zp_u!srbjp$XwutWulbP!0Oei0ZkK}H2; zn9%RMo6aNMszqMoRVU7h=#eDZ2{!_y=#nh2dQ(U{jw=)>mfoy^$`JID7=mEKKo>yP zv!U8XUvtvfSaUl=Vs7pUnN)A->f8{4nniRr=B0kCiNuf8JV*y8QRtPP>p?j~&I>?E zoXO1lx*J9J|>laH`-e5biB)*68hhVT)p>Bg?Er_*P} z*mg3n)5ah&kt~5D#|X>Wv-MCk7f8fC&I55Nn46rLxdG`%mNOj{4uUvZcgXjb!fp%H z5c?1m>6nzjpJ;iM1ziUaIZ+KYHBnUJ(T|D|KXJlx23@9cVj^auXAWyZ9?9z`C7fiwt1U2ce zncm#Gh@gY;WfALeEXoYHA4(UZI0ZTdL<}80AO|cLWW&RS^f9;xJsUBmYBrX3p%;Bn zOQ4(E@0AZ7@=tTUjg@Oih<{{L!cs9zLmDD6wg|KPh$#S(k-$@MJ8+4iB8S44I7i0Z zE`0NRcTn;mtvGfmOjL2LXLS0XYMop^&yp!RhPm+LnbxwRY@soU7kyPV)FY8qu%_83 z%R=OS-2ktf@4G(xvfga-?bzB1Oy;6Qr~Oz;ly+!hB=0{AP_dIjn9CKkIYgj@+Q_5W zxY-?Oo_C0T#{REZX-0%jz&?T8mP40;RWlTt>64j5%+MoaUZtt%k*W<+kT5HV&Z(hr zH(5=7=Ap@#Xw@;U1Ox6rv;{zBL8--VmX4!)(h_8x#m7J!n?I)YE{E?Cn18Y)|1zL%q?`I!Z;YQk~a!#uC3 z+vj*in@P?+b3>h~CvPM+-brbDHZ*X2EG4a2`cEDJE_`Wu?*LO#N?jI4cVN7i;G7i| zPR4_1G)5LKzF6B(IcQlG?Gdx0JgdzCrNdLLg_UiOJdXk1G|uZBjfq$sa0gei@KEK% z>*JU3r?^-q9b(o+0l~ z12WG0Fm1W?D;hlIUnS(P7o3s2AZCs;r?$mL@dlm~ZnrTE!$Kh_s6D0w_gFrYU%|V` z(HkioUTlB5KqNEvxHm2cW~~tH?`5=JZ4ChtUyqR~(?5QeR4)dGv>BsWBn` zTBD_GQr$J_pG)V3S+O~6DLzr*%&4@)i#Kd$4@+E=9Wu_(wEu~-G-f~Z%XC(XrDZ%b z?T$2~^%0&2wGxS*aeJzsb={tR5JNfOHr7dor+OQItUS0xsad2^wRhGk>DPX6mhKH) ze4R&O7WbSkyyR!#@irm54}G7yMSGg>W!ZhRfP9IUo4dELO_X)R(yzlb{^u3hWQLY~ zoxWE;o+*Fd8vXCewm)KJ?V9RFudiB}bPs*__Tqkl34%tdsj)FB?wEXsRX5bwXgk5r zf%e5T%{)OlkY;=;zx5|puj-Pew(7II zx@f`$>O87zV||})9_D7J&RhMFuYM0Ct^e=)lt!;B7Tj7q*sQT%i2UQy#>M*7w||n8 z-<9*y#UFEMl;TggDf;i9>&4kPI6XJy()_3Kh?8p4PFz!Jzj~V9Rr6)r{WZfiFH?U@ za5npYKjLpN3Jcb}3?Pl_8$&vI9r7|$3C^pj)BS(H=lWN$@~b~k<_Cckz9AbU#qZV1 z+N+ow1{oaKms0o{0*}5kUHwEm{JouZNy$B4rm&iO8>K(eAZ#Y{Cr(YaXaxLP|0Gc4 zmV}k#*!nNFrGxmIyS&MfeieyT7a=vOXG^u%pS%z$@NPxb(^Tc4>ql|V=fXJwOD_hp z>Q|q=c*l69@$m{1e)vA-^hvpt`jUW#x3U~}Oni-J+z*nONIxwud`!vuNeN}b4(@i%+nnWUOR!6E%E^5NFq zb<`>p!5!Bq)h1R9n-f>+IVEltFAPjCCy2MN#NN8ACGYwA=YV3?uy0pQke}BlPKtFN zadhQtXE)z6-FiQM7^at;7IpENKGI!QzqUwc-#aGfsbeTzAhNRayVi_e!C2S}m(3~% zbtUIxWy~{~Q~NpBRgJ;mKAx&LVOx;PD3f!6?$1XOaU(?&2|kK@EhN?N8=KvY7f(4b zAiukoFP*QhE^WvA1lb2+(w7`^-F=+&SFe`i^Z8nq<2i4Y$)}>KE3Z%WVWPir;=>ci z?{33}scn+pxh@;|?lTpa9F8?#x}#O8yCyxq@`qMs;L1n8+jTn%Ry@)^4AvIc?_DYw zJk>v}s6A`NvDAI|X|sBG8&_U{dai@@g0aHzaC=cY|2hv%`Ps0p@}BRLYOnabejk(d zIUAe#`n{=AdP=|NjW52Pn^m}<@yjj5FDrCs-R&FktorpMy=9m4h-AUo?*>o#*pd5< zXQdj>OIGy1l&*W(e6MflwJ__{io0RAR&S&rm7kGgH5*BUFN4*R}Y-J>|l_;t>8 zb&!B>k|vG%y4VBe!YM%x4a3d`51gdiPdi@owSA-*HfwuQKuFYkk>Nwl@AflYZy&6D zT`013JQf+Td3Ce)*|PdwL6r6?+;Z>L4U+v$qdC+g2FI^kT2A(@zxHRTbbHOZU3r1A z%v-E%*Qol(?-=zlkbX5tcUXApWLqJxiO2)~sX}#!j%@v7)aCoxI+w&;sc-QP7G4Ue z5n;*OdO4X&Uw2@Lj&WU;ruzu5{w_Ybi21w^TXpZ%+B|tNabCr!p8k>XQWC>$k6n@P z$2;Jr!>q4+%UN1hUDMEIy=rm%aWT!4hp(uK-I$P6pwZcp|B(OQ)SugIKQ*_kW*yP! z^$SeBma^aY>@1@o$8w%eg4TL<-RfiZ`xj!9+68yDUX*(KYpC*M*yrK!_=2@Bj2t<; zl4zgE&)t+O5Y(&PwQB>%yLac^M-)E2d-Z^xf$p%D)w7Jso+(~Fd|KP|yXSb1I37N5 zse~cIGQVx|lrHlob_qk5nzdQO^s2Pf<;BLqH6jwhtF@B-KPk)T-i}BMTg7oz_-R}h z>$G?1mPmn=1FWpy)ptpi-srSSkzr1)ZPPgPAVH0{)k5Kjrro)`dzua5>*B65B%K`C zp{_fgwkzATjyLABcc$J7WK=IJ47DDHn9UYVlVY$Nrm~Hl+%zJ2kwqzj-{|VNJw5gABU-KVQ_Nci)iq-! z)*D?F4!tCgFLB~uNhqETdDi#3-rM)2%b2H0$}Alv_1YO@<)BUK+Q}X{-=f+B1v{nG zZ9OzUd~gn{4`oV_3>!6-ih7cH;Y{tRox8P1icjC#e)DT*zQ?uBs&kakCpw%j!;Y6 z%DOfEse)fp_n!r)Lx~2belH5Gs!U#cA@Eae=T$3TYFW+SN}OIoGP%_LCi`E~>lFKn4j$0eF|(9aZBx9KWBDzrzduWK1NU*}>lYIO`~>#t27hSK3;Eh> zf1@&aB}sHxmx1!Nkc@FOl}7Wxwm{9IYxicAt_!>w3wyJZn=z1&HS}H1qjen=TRsHL z{WLW&u@sQrdaiK;RdB|Md!yU~&fi?xS;P6ZvIRHpXDuxZl(VMUynVzmg2}GWQAT)! z|5fctub`bLUc@n_Q5o^@4OCid?{|u;yP8AqT|!G?)pNf#SX<|A*I}XTv5N~EP0VSp zwj`f8eNRP4n>C^Hs*3n+4H{LR5gF&}B}r;Ozl!D17geu)N$VENdZV;QFJxwU|F^*J zm&V$e-m6^aR-i2HQM;pjm~EY~=HWRza(C0x{p;c1SoGtu#0n#hgDL z6|ExKljbX_qD z9(X%4q|_?-;ChDI?Msy7h3v1LdnjK@6lkpNO3ZnA(=+Y;DOS4;fCz8^?`tIgk z%eHqi$Kcdwj~lFPq02`&epB8|6X(=3sLyb3xx-@Go8QTx#1ObsvaP9>zFSuLOnu^g z#%OoV0~@|oFMjo2Z5hkZP;dec%lqh{>o{vRBC0|OkT4jU^@9XY@c^COG*8=~r- zy&JPne6;SRZNpc&>aZQPuM1|`CzKsNm^Co?T{Ax(%W-Zph=wg=9c%C|HKXh~8qW^CSewXy z;BAb`v~r)i3%gB{g=NUuvFxd*N!$Ex)mMCQFL3Lk6LQ_?Fn8$GNEaRD)f?~r{0I#R z<12J4y1hfNSkmnNjUca9L7R7;i<`LEZMZqLy6oviuPN#X3S8#z{kn6Sx&OMIbml|x z(7JfWx&oKUpYfvkGaEGO0`-nqOn$JAs(&1l8=rMOZQF=+)5pDBvC5{oFK)*i)Hl8r zdUTs*rVuZ?QwsmO9VXh+9swQpbChG8FK_ZZll~yy^>)}k_q;vF-hN#sI*)7v0n3+l zJ=KzHv z%zWv$Dn6jyfa&^CDii*dib;{jTR*V`xS!M6F}-;;wYP0kjAnNas|;O2e@k4tI@>sj^m!}2bMhbqog{U%4fUV z>)piSA6H+!W}cm&bWdZig|_j%&>E{9_m4Dx-v9PFl|_fg9M6p;t;b^O>-D6b#1;e( z@$aSWb(T?275@{kBewWUd6ZGQG9fD;81|mK4#n>EAxa zRr2=P?dY5@W%Fb9pXb_(C-wHfV0qIPSdo11&Au9TDewFK{W9vi6uh=Cd2FaSDVyy^ zW5e?4xG{6(U^W#)=uS)9;U7gW!tIui&<{BuWY9REwj;+ybG2)27Wdws-jAua#-BZ| zU5%J{V4lIZvC!3L^Kg{J@mQvJws{Ppd{!k*HBoVVQg+Wcrel1Buf0vqkS%oB6>On% zYs-vDUd*u7In!9Vs(pAjOz%trJ59eII;lgUby@EkgWoKDyh%;G=5O^`jMBG;Yig7+Y<;kOUrp7Yih_?%_j1RzcDx+h&`Ecod;Ahp=}R0}P6N5%p-96o znlh2pJ0e)C>R(JYUpidWBqRGnKdDpJ?QmuN=#u2qH_~m-xT)!0g$A6i zPqR{8J8DF==v!YLF7WPqOv48Mm_>H^mAMa{x@H=}QlrwJKe|3=PZSyGb>-C8tR5@* zQF36SyY`Z|*I0pE2;T|X^IsV`4LYjJDjmuw=yijS7&H%Zu$y(iXYrY9WfY!zUm(dS zX`p-hWAV}9$NV2EtzJ+Hi?RH8`<})4vr7nG*qQi+_xW{xFO@ir4kTP@dmkZsdai1V z{M)fY=T$3zIB#wf!JACXV4Oi)S`;HE7E{Wiv*vHmIP2y^& zS6JhyQuKCiRoB%MNxA17JoHGYw!-QKg4C91tYTN-4Td-v9tsEBwEzs*S7_Wj0dubI-+kCz|4 zH9whS`sm?FtGHRZc$1Z1qXRKFE)@T;Nk>XLX2RAyr&Zy!>U+wKW277ZcEt4j+9^sQ>6qg$dfL|ZxbeaC3gr2O8W0!6VJI|VyM_cN-roQ>B zsHwSgni!{Hm$Hhz1dGo6o4F4w#%hs8<1e@*>+i<*9+N8U`K99g=`_pNT#YoQ;|Fz2 z1u|q0rmWZmtnIHDlYPndrm9Zs`N)qJBa7^=%7zs)Mh)wWdWKbF^4f7%*-}%oq4D9u+=dQIj z9lLOkuKxwwZpNQC>v$HAM!sjWL||z7(DyN*I}YrybT}}%9s{Wx*p{pC{iQ+p3@X*Ir*RD&x0|9$`j;eD+)G>U^>wTvs)ItgoAv)@RWl5sXPt)r<~C2G1yN zF(>f|Egm1{H0>0M=Y8+?b54Zk!?YiGvuQm8>5%xcQq(H*hQm@CE?Y~51diWf{$QEh zH_6>pb%kvucMeYcYq{S0_rRk?c=f#KdA-o(iuSClOUD?ZP-EJ}T*QkEX{*}YeGAVr zPhC--KXO%5Budb_1(d3bN=jDw$P(DrYOP2(w?GmN*aAfMZ)T-+FZ`LbugL48cIcB+ zUm`x776`%=sB5mA}GX6xw?>Dq$~0P z+fq+!aN;bPj`BiXJ5kBp{f&Q1aNemv1Re?iFsOV#_4}R~2>Ngl_zK|npF77>-F(zY zmSWyIpvm>YeQzmLS6Qwsr48@;N3rDayRv_-g;f}h*t)GPbiFPw-7Q*AdmegYQstGo zfNhl=@~~D|kuF+DsGa(aX$}pk6(SaSErry<=Kr)H@%<76m5L|1n$t51S%tQbc8uf2 zE8s&_O0AuKHNv|ECsK;LS|#teFFR5iOuWovf7RAzYS$qMYClN6&Is52XiO>#SEq+x zGp{P)5ttYR7Ez_BIS~ovP;Lh&DqkO#%ZfF5leI#uQa!n<%rpqE8W4H3h&3Rj2g2;< zdv12@82IQa)*hB(CSE&iRyss*@pC+c01MqSmz9BUu#I{nzA@INSv9iAz%?*0WOp3t z?Q|+a$%jc&edDqti8`u94uFsIvQdfPryr^rAWPl7+-~=;yhU#1KLl+ zy%1k5LMwucpI*0r{iL~itLzFeB@~R~)$H|^H9w@H|nMpf4rRpcyPe$I*#wD3tqaPH39{u!_*lW~=^v^>x{S5M_L zggT6lG>$=oFrqRrd@hJXrj@Si=_Cs)Mq0yy#noTm>*|b=+K@Xfa8X8oUZI3X(x1N= z+3czz*!Qvc?%XEpwe161gYDgHyVd2zig(D9!(5vxSJ|7J)*sAcjKX`n2JH4QFQxMu zaBUJF(s=omUVG8k@7ZwHY#s>G z1YPG2WO{3h_-%dyde#iKgoQ}#x)lH~3P9pf9kqEiuY00lL!bE9-n&bg))_p+Jf?DV z7ewpPPvRPHMm?t;TEmmlo8wxXaV~P37Q6KZjoI_nWXNea1m=F_Y(5oN0_T3vn1C~0 z>|rhk7n0RF_#@M7HYV2xfg9TXW%BJ?qmJS}-cOCULysw13)*DRBqrj$Z2KjFGInV$ zj!L+sc*n(9E>@|7nb&2_cgL2a0F^+4Bp%X4UQC@a8>7bjUlu^q>UC`>Hx7!DmGl@G%Eju*(8c-xp-{@Z-MI6mGjiCHhHodVBdyeUa4;i;y6X&pJCa=Ckl=;pPtmV};#;%$%@bC4{OxPwRR~KT;-*@D`dIaJA zp72eZ!47sCn$xm8iY1|5#{;<`LJ~Q#s7hrXW-|6zRIL9-)BkyQMFtE)U+ggq?XXa` za3;H(sppWGZckvit#92wyvuf==}BNQO2wlTS)H0nszTdZatdO=F4iKn4*21SFrWSl zt3jM-^dwjqIUb8{qK;RYDNQEzT>d>z0YQPu{3ieIUdmTQn9%*!FCwHF^2KzGOesIh zZT9U3uFYaguMllW*U?(Lca}0NHnYI2ef>gJ{hi98$&bj8F9iNZ@P};*=Zq{t+3_UZ z!?E*bZwttj-GJ4*S<|R=wADwX@*`*`e`0PRVR(A>d}6OltN?Y_ZA10mu0lK!o&P#s z?Q-{IORcv)(D9E;=nKlj7Z|PNBB|PuukWBTUzKeiUQiAy`^D2jzw{0ZwXtB5yArvc zU`-feslK;6Y;>Z9bgdyjGv(Tj6RVekCVOk}bQqv$%xFFjCkMq7m298p#~3}RJLO${ z#PNgX>}m3RP028=wYOZCBbYQte(3s9(743Uhe`NB`I?AI2u@E$OXZ6{E^*s=1)Z^! z9qCzQtby4tr7^O!8On^R@Y|QE3#=X1@d;MQ4;)L}gOeTpkTT;7RQKNv$gdXK)~wK>@#wL;l$6m8He|EG^u-TBXr4$i?~7q``Tu33(W; z=&UE#k3^oM*~uR~YkQvw{^261aL>d}b0N?MCppmS1_>ulq(`2Z{LJiL^b+-ET?1s|VNF$Z~qVPyae@ zW1z3tT3_t{rxqQdg;lTd7SYoA@anz<&CeRpNAdb>x0TA0*&F4VQE0zI(~LOf_IcMMAKR*yt9~j2id>=ecTS5@EUR!) z9O_md>itkN01}uy&g!T(-o(?~T(rvRs%vJB14KO=dbADP*ojEIGRr{I#dm4&y}^mT z>@z@J^8SY*5tjpy80YCZl!JvdP7&hb-OxKS>Zmujtq*%8G82`owB%~_mc!1?C&H*b zS3@89s?%j8n)B2g^p&T3NUPR^nKK-DLQKKMn#S3|{Z`u_eZ&KvYr40z9^n!>n#%3# zLsSwR^##8lk@^2Kh)$wIO0?bf%rgw#+|9Q(4I$YO8$lheF>R2qbSSsD&SV6d@KsM_ zhT#b*ZG!xEZ8-pR2ZvHrc(jND1NG>6GZg%}-xL3L3IDxb>Wl)opG&b_yyNQ$qJCT0 z5BAAzZd~@lW^mf04v))$3wyA+hMN>eDOaA{Sw@_Uwt#P23Qf0r{KCeYF_&i%m$X(f z-RYWrEmu?nTz=OoUO?{%93Gv&o3i#$L6<7j@V`AXyUnx9yYfVsiC=20%x(Q!4Ejhe zzWL=F|6R^oe6=~f)Nx09LuqXVoVe<&>Fe{nK0G0oNmFMPY6{&6S+rokVfM?M#ak`` z;m8fFQQz<|^z+V?XqLgfO&ot=on84r~WHbWc=<7V$o(2Hl2nK=*Qg zdQvx8db|CNyOsJN+tz#O6BNS|+WqbYK(j(^wdxNSke@lgrXkzMYJy_H}>rGT<(o>u8r*Xf& zOeJGf+pY4{bupe4tmZ-`w>(!MVl81owRyQ;HM(A>-hdZl(zH_qJzKNHA7N=VDZkUD z_4(5cj7A73i0K1kbLo6_LSG6vyPk6RorLW9wF1#-3(&i6eQ@D9c<|{mf5ls5ST>C5 zwcDT&NA)%PS$la8qXTNL76}qimxI4jVGYiMU=Ql{fy)CN3x)6Xv;#v#4a}3%--|o* ztW8$?A7bG_kb`Sn@TdJCA)Tp;;wIPlnpb<0uXg1^Saf)a6V~SjBLLK&YnzI4JhMDb zAQovh@OiGu!6mr;6*B4ON27B1d2mi>(3`GL0=KMj%EbR#`MYR0vKN@+9gKoSKB z!AQLWrl7tBBc>U5#AEm7z@El#p8lTCEg$uh zW9XGAh1@eXUsfte%aSuhn3R^@vyNJjLoyh9GXK9)d-p#|ZPyp=?%WzPg!R=zhUHz4pmO?7FY6e=h-JZtTt!W~jZJ5?5wbQ^V+%}enk z!MEhNW7>{PU%DQAB5|bZ3UNtNnw#WE=JK*NB7oBsdKC1pqZp@34B5#vJvB^{H3ifo zqhSNC^LE=Jh_?5f9@yrR%AoR1T)TVqro9{v{Qj=yp4$WTV6ao73lpsj@2H3B)`Z*$xolnvmGdU z*tiED_`W7v(UxJv5H0V8PWaoLQxq3lhe*>4E;D`=yu+w}X0DI^U9Q^veh+wa9_3J? zvqvdU&PHz?#eH9BU=f?wRW*e0;2wHJx!3uYB!c0i=7BA6YBZBnIQsivdcSga+`lN; zlYMEV!H!8HxN5Q zX-=+$Hg&ijB1GsRCje}&HiEd z&=S*EAFO^2@42I>0m}tN#R6SA@N7zG2!HH3^x;hB!xEIAqa*;}u2l2L{$bKtwJJ31M)W93C z(jbj1^}g=$W*jgI<;Z!%R+`f$A|7f+101`-YC1M|T(fQ&IJPQ^k#a?CiTIsbk|-eS zW}SYW6P$mnSs@U(l4uX8Vxqz3zK_IMsKupt?7^yCTX`^HF7CvSI_BlmFHK4C0*}0 z*#7oM4cUf$U4@p?j)#L4uwpVPuNEV7OOOez9?B+gx7=4qNu!(_i_P5{h)lg;8V%sDeBOM^vDiKjEiC|L1rDoD&j9Q7Jqo!r zo~?EGC(GHsb1=yR0;uDw8>)}0i3Bl9TVRc73sshk)odjzxgAE?2hK?VUIwkQhoc`%AjG9*X z&{9X`&Er(_TRM$IlD^Y=ZXWXeEcWw*k3KY9GxJRMXLdd;>{y!Fc``;2w(cb#9KeSN z$F7iJAXam2YYp~{Gi;-qapo29R>{MZ_()=bx~6r2itbH$gB90sY!CY^v+pj2H;)B0 z3gAK%_)Ppy*1Yisi7iyt~Yg6R%ow6c2itj?ySsbQ#5bgqFYC?kr53anE_#X723YH9SY*3xU%% z)9~Yj%kk@{6kL_n=dWeP@MtHJNjRdQ-1YI9p51D9cFZ9)P~d_={k(CA1_j}sZgLPs zUSAYCh%-{B;De}zO}t#N86F@Ldqwv_V_hcgY<51T6h7G+YbCY`BB@wTh*4|c#P zSpLQ|Z+JveTo^!@rrjD=t)gy146sBQek7kb<1N&4>V7u#kU&`Y(4~)&NSgz-B{Bp_ z6o*ZYfQKl|%cUP*@hVvQ6S@bLP+a9mrIBHP(!W`~k}s`1;?i8jmrq%bUdKJil;s1V zQxX^rej=^@ZvZv1jBRQsc=ydGXbyIiZ7oEr4OQ0rC@1(tWFtsB0dT+iX0XRF`a=+@ zZ58XiVKal#V!cS6KL*yXDjP`8Kf1aZK@4B8(a(>%6&KHy(W z^^lt7C1uZWMYN@T=W6QKj^LNBp#EURY4Fex zy|zwGqFIt#R3E_7vtBU#ND|`5Hg=7$hqMm-`YO@4l#ZwUcCD?Q*}t$1aVVM|;4vhP zEg;|ROrv+J{NxvN&I#n9LYYPC2Jte@o7jo@p3Gy*Qxf)x>7NYfM<=S;0Q*AD`W216 z%RNxE)S&Ji{_BF^2Y-PcL{DA~6QjP9L_AdXTU8xO3Q4?0EJ-mE%VmGS_69g#>Y3R; z*+Naqjl~wIh723Fb|_GQ!ak~;BIJY4M%q@T636u7_@zg_ne4r7`~^MS{dg02R;5ru&~d$mWz9 zbXczuPXIuc`uyhjPV1Fg7Ovt4It*8Y|A6wUM>;=l_Aq0X$8vXPM^ax09+Av4a9}>s zmZ_k8Kq|rhG99FNW^5RTL@D5rkasQ+G467_7hhDe>mrn%Ktn5jNSCVYx+Tf<+&`N2 z(mDOv)(RSy&%qx;uT)YJk2e9H8dJwq1bu{t$reTrY9{Z8GmoJIAiCq<#+wZ>Cu)Nd z+s9+^bp>rGHQI{eEnS@QAqMURoI-ilU) z73qnHTWsu$20*Pgs4y==@jwX=_a%eLBw@g%Eb=;J4BTHAtB$8H3dbnC4|F>0m59oW zq0g{(hadL%>37YhGcoJsJD$RfWmw@uJA}yDUZ-877IE@m8J{#a#NQkLtMZtWzow+) z1o}z?5kI=iy-e8!aOGRd3t9!POXxfUMcN>-f>e^2t0#Ien@KcyeQX%&^<_A9+ zx`4mSF=&YKJseoZ8~(|#7HXNqtqq2KU;BP~o^>~lvPfl}p8j*OrZ6`~8+xqX+3U|w z*O=#pM?!RYS*85o{AQjsUC-=%Fn>?w%BKzoRLKwET#Lm7tFDI?F+!uPog3l|gFVD| z6y9Fz2mk?Bf=N3qzkjlMR1j%iRa#(|14Xr`ol4gA0D8yAPXu!ca80G<`9;0- zmLWv8;ej(CW7gCLG3&N;_$sCikcVD_@JnM?2p$4uM&&%9kjkBNKvjlMYiW_(u;b>J z`4IvqZF{XW6pn&IwaqZDl&y#BTt8E`ttmD4M_Q3K-o(xQ7MIkYF9VO2(DV0{MbEF> z^0Egb_P-HArN8Ium6AF%JR(mgWOrYxJb0SL!f*{b4BJYn3ikpyyjE1GrK_!fnNqky z_kKxNJ)U-gpnF~LZP)Cmn9>o-b-By^+%SJW1&T)na_M^HWck&7C+1S43*Se^+Jyzt z0h*`r3IUJ=$npY}f|$!uHloY~2~fr_JKv+~2{nZ5yV_oSp=2aVL~^8QjZRPw1OWg# zl`r|?zRjC(56MZ;o-s#~XX>lw`={S=n@lhvCI!WLXCINL%V8Z2NOsJ;nN)%${j9*G z?w5y!Y0uep@gQCTtwwo#*y_I}|1EHu&_BWTx#0InqD!;lvLIUwoz>H3h0-)SmZwXm z1whi)Ak1*;2?i@w8vN?n$skH17f#Sir0J>;&M6hHEi1`)0>1f7{g*S7-|{2pLuL)% zkA!w-K}=b~_Z2ku{e`40MDs|l-F5o!MXoPg$n7G5oeu#oqkL^)^h;XED*UyAC%(3) z_Nh~4j9bp1>VAL{S-zmkd7s5RclaOD3#w{jN@%~YG~o(vCz;&p42F>bopXBbi;I+4 zyz*CAY%cuErgaM!hltQaY+Vr_{UF1Q+_sJiQ)SPuWj*1RMAY+46@L#nAS}gNJS)u6C*FDSQf0z6;q2q&EvjUd zl0?82tXn4?8JjEda@Bs0M@BVVGPYDn>>?>f2-&f4WeR6=Gfk10LK<29C%5eRgdf9O z>}7;A=7QLGhqdjoK!DEHbXRw} z6`JMYsn5?rh#By&&XQ+=aNA=Ac+SAd0rD6%T>;s;9(j!8?%Ev2{74FPmp3bd_IjKI zHLK8%K~SOo!d+jddxU96aXuFE++S`7i$qCOH+X7Af=*BLb<*T5+H_?{l!uf>z!l$@ z4Yx0|yoa>rc9ZoBw0NA1W;R#MlWNiMqYj@ADO}CuK+BE>!&IB-M2Dgt^xs`dk5rc| zc&@F`SFWgML3Whs$DBBqVWbFXv%G@=n;)2v$NdF^ch3D zT+=nZb)-O|2=4l+(Lf9Ef7~IXqgS&V!hN{L%e~D6x?W!uQlQ7wd+VL0(9=WM;UF7j(pSFV)Lb(F zzyc((*5DUQZ*Zd-kPx|?Y4g5c!%;Ol*?DxNv;g4`n`t;OH>q2WR*Lkwsd`2a{lFE1 zmoIh$)O&@Ji4WD|4#nc!{8W7NFQn8_^S{L-wKIMU5oFsQGm-OCb$Vp|R{qDK?BL%% z*>`+S18cgAEYqV5;{U;({-dKKjxonHi3xT}%5lw{VAX^Gf*fFD;o&zG) z8C6E*x9(-`Ob2z+n|l9Ng|}-f8^i74Y-ZZZ!+%2;s5h2fBMV_G&~u8R)Ux_n?B22k zu}>L)kbvuhZNgTgWj+@CuF(U>>A(e7Ct{TKe!SDX(fw9aBROXC*g~tj?K==FwDgaE z)ScoBR_m1Vh5Z(7#Fce|BxlWmT`f?Xurk|fMN zL%M&)wi8eFSk(bXMqL}Zd$`xc*DBase5ms9_z%r?c+I)l^s)zlUh2O)75jYCBx$d+ zVtv#m3mYad{JEy00+F zNB?R|+HO zEC;1oQ{w`aCST0sW;FasV-($+Wdiffh2XDf&7>*CuRX`LoDFSWAN6>c_AGTi zb55U&2oG{}kjW+r1=LqZIQ2-$o#WM7mUuPq zFQh<-8}Ho|@gu1q{#7u-6=8v)!tbxJ zXb;%{#^Qns{d}hG&OlG6*va#MVq-G+Qp?N*1;Sw*6S`eD9NYO6{dgd*>AjE+Km^a=MRqfLV^opF_iH7g zz6Z0~NjQ~StTCW*MiY05bl_NQdcBBd3ZpPKv7e}-n!NT8k4#<_9EB}2(p|JY15rli z3G}P9P5+$n97MDp2aOSQZ|7wd)M`*xtW?~1Y`+!!s9W1uJ>)%oHRnEZJP%R@Ol!v@ zUG*oBx|%3Qus_d9mz+1%w%ic^ttwQt?JR6l{?2N=SLr5p7})*^VE_~bHY8Igfxa!h zi3u>K_l1uksbsC?G~|gGd6J`cwel$rHO=X~UKVIt4jL?Dm#%+Ya5L#_pr@aHcaXD2 z)_QJjq zEx1OK2|e`Itm$(oQMMS&??({wA6!VFMy47JSPFo~k+C=5Y=IC%orW#P=*6gjAA5ccSWrF?P9BW(-f&a?_ z&^QQxo&V&Jy&af!_1&Hlhk{y{F)&NHfEj}>FDl4hJm3`>75Igv13T!pQGJz0JV4<6 zi@|VFat5k3yZhC_%Qa(48XXWA_5IgdyaZ+i5;Fb3v$h9{8;D`;xUD=4O?&P(7|7cS z*&XOyz53dU($x@1p(U{nmn4#~ZQBK0EWO{Xp+EstiFz-OR)}^o-`_Vt7(}L3;${C% z^l93$Q`X>|11AdEu7t}AGmJS zTRwtS!n4wbq5zr$PKutk$$&vPdxCNiUgAK1R5?0w2WemBr*90JON3-Kt5|{2!-;_d zLxPsc#0RwR@N&I-q@x3oDOKW@9s2SXaH;jJo zFWM-dj+hqfJ{e=fMOcss(*ER=Mj~N-oql&;Ej&E#>+n+2jx+b z(@0jvdy5)3KE)r$a`t+D>wcrQrF}QubzUfm~@>dP_8FeZz4QY7-TKEfiYz z6RA|olgNS8PHS)XX-Gp*xUzO8ojhq!Q;IjC6)KRuYF!hJ&UXH;>h$mRg}_rIoyyeI zE4#`^?PJnzfyiEwABWaiZ*9LA6NSAOpnk4+Za?U1^LUx8mi` zlOo<=%=P{yxNw<+!;E(#O=fg5t?&oG2wm(SS+saUK$nMe`T!u6{NVVC7DjiWPvk+_ zHImv1TXOCo(&ieO##n1Uw~w6I;QoB_8g-uO_OmVay=yfYLXd;-c0P%A!Vwi|e~w_O zrA!G2zdSf$XQCJ<5lSV40MWQ_n2fq#8X8>UeuQfHe=G8*X2Q5^+~UqUxIPNk$!EPP z63aBw8NF&}c|QU;5W8yKO@nUvR@rCQwowo5><&XP59Kpp&mZuhS;nwGLNdpPUv_*z z1PYpk19&{z2iM~MyAI`%t`i^ZWJy>6*a6E+wh5Wnc&N5Uz^ke zXB0eW-xmVk=N@sK@tkwn?6qi@MrM2d76V{MGqwp!$e~RM7+OmoJ3S$GSlm2u_yC}X#5h;KuuP==1iCiuD4i&0773!b%%;j4IOl)*sDUK3!9sZQv zyH@$4xorER-ZEAJFoxra;AGb0lo=@_)uHs@4?5bDx+DTV*iMgnNdA}KT|o3fq}~Hi zE=xq8pG{8Q9;!o((&UCslNWlO?HAP}irs=`+y|ewOV5f0_|ZOyBuJvAIBCBL(82Fp@{VjEi9nNEs$|8{HXGiy@}Rbp;7974VOj-og#L=U zRuOgv;b(oqJPOMf81IincE9gbU1>vo$P-^xs~Qd}VrERSD3R>@g`{}ZZ*o!$WDc%R z5lM%_&86gRQ7XtIy=^894RPK9zRFVlG0!`@Yf>_`86gf;P;OtjvWg4vL+@-c^*Aap zWKsp^*QOe1vWo{kOMTp9*|bIxAp+WTD;;;8OW&=#puN?{1Zl>ls@*$*c*_Trv8MmQ zg1~6a%R5dgxCeR3KqLllI(^|sZ_8Cc$c8$=P`j7p>^4c2tO@JACq;ljD!>)2 zw?~OL1byGE@>=|==RCDIpbZV}GO8?o*2>~t&oAJR{~QWs&g)zn>`gVs5N$RP zLoC_opYwJH4D*9#_}jcLt_+r2ny+l8>rbpe5~Xm2D_qs+V0rP*snI=fP}|%0ptty! z4q+|O@W*kot*1wcE6bxdEzO@JmqUD20mA8j|A#Gi)fuH!W?S2LbTBvX`-hzbE1&F-&2l$Ad=qhnau!IVq)RP+ zWs$kzpDzBL*5d&(%5_GXG9V>fw7^DErZtm8{s5wI32sT2moa>;*ujXW(B=6S25 z^^kped_#xh-Cz5;EC|&e+Gen2w%q%M^UUQrqSmL;(ldqx#cSXkSw1ExWsIjti)8dM zWJv@D8g0jQcZPCe6YCyau?O5RQ>Vp1cx9xAc25MqlcbO`&b*+;?B5k~pO!KAc^oIWe=|@IU!I%`=-m=S}G%PtL47VS70cVUC`&r#2QQ zQ3$$DYkxRu4`2H7gu#u0Akt|De@#A?p{d&kJI|ZeHCYi2rSL&$GXqRyuY-hrAR4Bm zKkxnbxzP7?Fy|*X&l~hf)4>xA;>mL9MeGe>uh;-zXQoNSU6H!lplNYu5+D?QX1p@; zD91L-PuwpVhJrTo_`2+KekfFZeR9WtdJ_~JydePx`VgdYLO94+{xQ{3;VjjFQh3T5 zwm~@D+Io4no_w6*B{&&NWpI|LqEPasc!2&^wH6T!2anged&9^5_`-=A9w<*VJ~esu z@4i|lVt;F50$3ik@hmp|D1z)?eH(YVb4F||;1=OcBq4U!s|~iVEa_+rY)@ObEaX?! zFwy9d@&r#3{oHtedv63tLO2%3ScxI5?>f59u?#$OKG71w%F9kjJN4=g0$saNeuC81Oh*L%+Xr=5Bfi>BctA z8@|~~x?{7xTqC!AhvDtC?%VbJON^`SYqbG#(paJ|#>;G_Xr+Iw7Cqd5|)`03?-|y zSzdwEqzG_`MY zu-$;=Vmg|Pw*w44cV%7?k{hMA?v6>J%WJ~Fz_M<7{T%{!D2F$Fjfk>e_6S<&;uPyB z;jBmZJ+-ae1c^*}f z&y!_<)N{a%TUlJsBYt=7P>Q@#%4mMr_8Wfd@6XnX>TmCY2?XPr-y;2HcpyX)Q4ZOS z4*f14SbYqk!HIaS^JpH_Yp$UEz+WJ8SX>YKAXLcBL*2`x;?}=K6HrH0I!*{%_* zOTZq7IaPPYIf^z1OY2W7m<}ef>_8%9!^>;E`|-78N&+kr&&3}Y-%|ziB%|K1`Bmus z(wdy}=I!wR5fl5usRf}YR;s_=5)`}r$_qA7pBDFj5Mwhr>Mbb)>`S1+^cnk{J*Yb6 zMKI}&$grFX903i;;N<+S@8c{ovVY5OGO~!Ik@US*d*4sbz#`u}3=D+DK&tLV)g7R@C%vesNP&s#2$*W1Y>)*6#5^gRQeD9DH{a&EROl)B-X|DbgPr2TTJSOS2beZPe&_Nhq}hhgKQ1Zhvk_(cQn*x8M|HFH zA|%c#L)@VMZpxe)>2Ue(-0ecNfI_79C!cQmhI0{i0H54P3-p?CRpv}u!q_A}W!5WE z-#@vQYwaFys27{EJ7~8UEYs6VXw%QPzs4+xw9N_kCDl-EhZdglq@mv5_`L%r3=%?B zc-Ok?znYv0MXi=>c?f##%+%?Wn&aRO|F6&SJgJXisw*{_{Wxd%PDH?+EL@ z-?k%cKcyYFH{{C-EB~}J$}`>JoUMz~^hVY4{tx22<;E{!*MG(#S|*CHf+zuPGMR)+ zpX?$op#u4AY1z`A;~m$mG}q~s&{l_ccGRt$p-rAF?{4ZJnvv!4v%A-Z7fy$lHy6e; zOI*YrdkxB1*!Q+ol`2oc&<%;ytPA`z&oZ(?ijEQBS~MQW|Y#xv~%y~=Z7 zw|N1>or)snSKHb^=2trmPJhbXsnIn=5+*E9bHu1*7p2G5O|6{(Jtf_=msAag0=>a=kx^HgMb&v)Y-8zt+p=*R8xdo5dmurXrqr|v!>ul8hhiRwKguU zu+i;?Q$h7^c1BUf-ZW-@lKc?4PCjAC3OalC&-iEHSXpOz9@$_u?DEClhmhL?E6j1LZ2Z-tPP++v_ZHsQu)?tI{wA z=k-(^oNPN|s5uc#LBcIUeK1PJgf2q>uwOmURIGWLI_P4FaOcPq6<3+=qWyHuXVte{ zF&Hj64e8-_Q7pDx9*EwnHhElLZD^hdug0Vz`-RlPP5f0N|G|u7O3L{9G(a6m#o97e zvt9k$RnWC8Ee*FhQ4wP5A@@qUdB6S6rd;vs=HQ~3yD&B{17I!>wdAeX!TV^!Ide&Y zQa}8*Klg>_-?dtc&T{5<7z!|3L;B#)-p@rslNEh&IPc?i>V?Xtz4N=k*-h$VQrsIa zZd?`eH7y#-$@{$yX!{{14jnobCn@jk7_53754v@(+ZmuAlL zz@{R-CRq+YPWIrjv$|otoZ2gM&53yUvux+eJUtXmhfiDGNaGuvf9Lg{FylYd&+6H# z!Y%JVOT!JTnlNUM?5+IWbWE4rYAl;5d*6TT!NK)rKj7&xigXl7RmJM==$(aIPWRKv zx>oXDLw-SY;b+}ZQlW+@M6DDE`f)|O+SRMQsC@9IDv0jBr4(14aKe1|N;VSn3LI7f z-C7`IbHBZeX3p|CP5-yr}&F6 zeTV}FPnIlK)vlY3va&3YksZ!>k78hiv1rn*5(adJ`k!@)effC>vv}<|E#DOu9R3(zK4_&V*X>`8Q72W%1(=7WNT$TIIpfW(WV59l9aO zy6k}r*UVqbkFCEM)eXo^N7rnDPmr7P#E&huvcIW}sTv)*2bD!DDCEA?1yhsZ&hv}J za<@z52_E1(AfF^ggk(J4U8*NDnFvnA89TzSe@J2Y7)HBU_6gx{@dGd}`X-XmYkQKo z`9w^rjM25gIz7P81(%57VWN!1@}a}F`2;#x^1)e06~%NhL8t4n?~Oev8vLfu8Exo; zv4e?VoXR`cRSeVZHcVf^MsY(i)l#g(3?6s0DQG7@Zd5+oq|+^sU%z^<2m~j&{ln7B zfqWo|1&>3Ftk~0w(u{Pu)tVDM0ZbM(*f5U%Pn0F`0SRJ%wh@kNrCjT@>RqKY zrpH&iT2eyw_6T^VI@ne+?pOE(Gx3YQN~G1cL{8poB}@P2;%*wPHDubG%}Q2!-wDZ6>)2dme6=<*8W53bT555 zO$)Z$xNCN`lY^`WG4k@r)`0nXvc*+lp5E!)Rk`EeZ-Uq;E~CFL?tjoE1ZH%BZ74|z zNS#ue-r4M^BdhZK+=>JI+(LBl#@KVU{ z_3QZ4!&4jCuntj1?<{)rxws5#(sl;_^c<%^gi6BwhTVIf->LP^+T0KR{14Pjvo@46 ziRsad%bp!63*!#Ce%CzV?ATj%EcU{J3Q`XWFZTqVTkgv&W1jP~Sv@8C$Qp2>h4U1O z^>wXbH32VdikOHY<_h}y!Vr#jj8duo(|Q8BryLXh5N`4Y~=iB-`K z#*ZP093cHxiZ1h>UXdsFdw^HA)`5Pmjr&J^?;qgauLM=h^#cXmIUR0|R4t72yaKWF zR^~6+tcI$nq44}b$Q+KajYsV7_4?kv_fm%g=R4wshEHaGXVgI6P|Z<%iu-Y@(9B() zB{4)(r+U566P*spA&yGB4&!Ci0z*^heg{?e^6TFyU=a$iOCpLx<$zZ!9yaOGM7rE0 zk3bL+qXfGsP*HMfJpmFSKw3oc9Bnh;3r(8Ou#-hG);b$1jb8Uf=3;~cmk6oOd8qpdz(q)a!+eg3VWI^d-%Qxf^*atOIQl z*ViEi;83{zjk#Lc#uOl|sLf78M#Sm|k*j?qytw)uS8aJZ_us%NA``nN<&B1kwz64n zTnHTpgH1Ve5%ZxXJtb_JYIwQrj?zKpKzSH$^J!4kN}E^{LYZFwwLir!?#h(8XXEU@JSua`LQDWiHDGr8K1 z*Bf;${47p{;!rvwgKnD}frukbCQ{s)kdl-PJQ0J$iAT)8PvVIENxo;laV>H7Q3Q)VXcL~be+E90A4t-Sj3;^sJ<8rc<~h8jR@-S-bMH1j z_ajm|wzsd|>2246n^~A0KHEM0XNFC?z`t+mkoErAc3+%8bk{LLr8em*ge^YY`p{+&_6S+7JVeQSEk zx^LKJSix@zqzR4;&sg^tL|p=My;V^D^S739tB|)`7HJ>Vnqh5Np?K+ueQoYU7*sk9oV@#a)dR@BwUW&mPBGx27HboRmze)>H?5vf-UZJ=Hg0gkZ)7 zEp1*EQjoWrI#tjp@B#dQGR*#WTfHjwrfm*R3AR{JhVf{?*6?@*2l#yeNz&`~Wvl;Q z`+RSiII7GiN30z7{V+^{J_VG59LgxkDTF%A;;e=^O-O{-o%9Tw{7Lq|nDWwpml5vj zksROfh^(O4952oaT83i8*PM60@5n90mexP_n2P{x8GCaRqiy5o| z(E5RG%+Bm{>LJ)Kv_b9=N7Hgq(Ga{P?Tli~KXqjozhI&qN7n<@WS+Rwm-#xM76F7*+bXY}S3y(maZbh7>5+h3f@mzF8Z)b& zSucsX1oUqKk6R30>a5>ow83=BU9Wp=#v9(zvrhOd@5cA)S)RU3d(#U~Z@b#;37!HA zb3h7x$IkeXjAh9n`*KU^e_3~k6IhVcLyBArScFG;hey<<_2MB7Iwt@q+IYULngL$X z4jHiY34st&;^P$}S}Ke#q@}eyHd7p z>6Dgkkd8yAbax9#cbBwuN;d-1CDJ8v=p5Z@+;XUs>vG!VPukE97H_*eQ zn$~b+h(1kE>irbJ*|hv2^Y|@!+SKi2bY1d8IyBADzT`I>ZqJl{N@Z46(Qj^!o=gX5 zyciEI%zp>SV%xsqo+K8D3@YZYm14L=VNo*B`LDs3>WzZLG^T#vWgT z?Xj+j+Y6b4AXP1PXQtyGX$poiAmFigja(A}gT&7Wv$*4H`x-0%gp@l*!q&#fT&e*2 z!8hU6m3Ci7>a2$ODsWV48Ipf}Tl;m(q$l_%L=sM|E;|WSMuzDlC5$fy;~{T|V?M_n zj}Br+{+fzZ5KC222_HE8wf~2O?c-}C!AwlPl1PQ?HN;?2zAjlL`|8l zW=Q-yC7bB>jK-;qnQHw|EBw#I7pyFAw}l!G3_zhkJMu@V!B=5ykWfA9tka-67m|z87}=ka9}TR_rAS{9MJkMUdiN1@R8hHvgMbAC6*zBT4c|uOB#- zGck$Ak1W<(6_4K~$iVu2Wt=*HkYlMJzm$A0Qj3}!ZzA0e-6N$;`|B{LdWIl~UT4?; zxJN<^lN%ssnrW$mEAF=fg@3YbKqONrMoOt;hgGVqkl%94UpIzYY4x}f4)p!=c(IT9 zPIv#*q|EJ2W{=@RhbGyxtaW*F47x$EVTHTj zEJ&f_&L=xnZmfjz6#ic7;~Ri!%PpT6Uv6YF*bhK!QVTQ4MvEDaKUr|wL(OJv)m3@1 zY_J&ccmD-;2Bp)()Qi@!OQS~1Of|sxT?aBu*8z2~}BR-8kwN$)1N{_E)TN zC@q|3Nt;@=R7G3)@lr*~5W>YolVr-#*lBNfmv?@5CGsFK?4^Y0!@pafizta7RrC~V zZsvUMl{rL~y})|@yFY2v#5t(9c(HkjI~#!y$zT>@oj!ws8&Q07dwxIX_Yb$GUc|=I zUvZ@|JPB=2>t}WL)&Nt@{tY$pG{$1^8kwLYgTA{3NiAYZOXv-3V8uYe1t+R*XhOBm zvuZ)L;fQzOZcBcojA0{M2aGk=`_-Y6GI-sr)og%Y80>*;U~dv1P{RVMNDd#P$O z2LUW=Z5gl!6jwXLlZO!^6*LR-woBm!>m67@N|AFvd9svNwzew<2;~TzSF_gs7@E`O zN-S@J!ll56WveM>YKcvOxNZ?2Y52W_UN0ppy%zqc*@ZOaejuzbo}0?ME$8rL-4Db(WE# z#`U@fB0FHlSg|`z0ytC-O!gbz{lcXSCzgqRe4bjD<#$-4B+ZcM+NgHqM|og+HQBV4 zVp3;ETvUyuOtjNf}7>)31~oMwtJ%jSOUB&*{7=3LR%K*5Gk`)|glWNI2phmK}b zvQKe?ZVR+no|BNY@3y$|W{!H{3K?4>MyLV3l{F6Y)_9 z%ZdN>kN-!Fd--yXVk9Rnp|vVbvhs*L@g0OY_ZBd&86MU)&v695Gq0*H{uydCyI6Bc zZ6YqA1`(#rmfCP!_`!&BmV~w1|9oFZQlK)ZyTbPxUiM#ENoi{5!$=lDJs}dk{Fe(j zFKP04IEQ7+OL>GvliPsmu(8V>U48x=sSKgLG`7n3A0~Y`!qQJOxMC(<(}5C&HU%%> zk+qY31x$CLX1x4Kt*Y6c=Zs_v^$@MJv}wh|7@KhG_LS;x{xV$L?#$zgpSjbZIGfC* z%|UtIWC-~bKu-;f@Y&$~3}rK}^*v=wH%l^l6UdE~W6l^aHrNsB(hK@eWGOG9F*jFn zTyoogHLy>#zW*6)D0|r1T}F=zRW?bBS!vH4wIcYeul4H{HwY&l4*L5sermotFAPXB zz;0TsFC-JS_!4;p#8g+nxRdTnYxxm~YdcM`47Mt{yJP+@RR=KEbGrlHi(4NIIMEG5 z^yQJ)f2pMN1d^#z(jT~M8PR)9K*ppBbK?_(VFKb~KE)n=1mlJ}N@X{h`j1-E8xmN~^J@aN z#^zd_4y@MQ_JbO=NSke_X_rl1{IhC*{NxRkt_m|cWMXvJ3F}{pvjY9gL8H(jIT03M zm5GybaA6GbS>Oby&Yc$625ReFY|{I_XqdWY?rkCzGu4=^FeY=&*48&C4_d@Wmx=1` z9hg7wVva!xhg`59nLfo)e@;Vx#3M8rOpmaOa@RIL50`k_mL(I8GgZhuHrbEHzsf%D zhSfU^e7N+s@1@+HNjdlP_C{FsJ-~>r$BO;^HerD?*zd6}bnyJ+E+H|jCI$0snxt53 zS}BCj>W(*V24gHWu=D94t6W!eK#Xb(tySSVyl2*>CYD{9I9P4-DUFciv@9-cGW6xgRMJ1v4$$4I1GgQ$1)^<)7MThFRdcEV#9( zrBwzvl<}0)-(%guKcVbu{EZ|9VpVYzLL}zP;i(S4sv76FKUbR zcpdDF8ToJIjtTJ3*)_=1IjV#^r0GqzEKRL0LaWgP^PH%R7W1ZNQPrA@^Objq@!*oROo8)1>EYGll8$(B`JZ5z-ma|1 z?I8QTdv0L@@V!U$R>sMhe*2{|EAakPmx;s)US%%{a26pIH|o3e-90q)QGP&pel9MCgh1`!G!=xOsf$V+0U zWedqY$LjP3W=(a*AI0RW!$1trV%3*dsXpxDZ3(5ZwFr68>t!+2g5qSTbFdayW`^fI zSRb(-3X0=3kH2>bx@1ISktGsAe%h@eagGtUC}-?CcEC|A@_6V&cyNXh zhC|g2??bDy$+K8*|Ao$~q#m0u_W-_`Uo>vKTi@r1DSe;e4Kc>)pN9N4hqkOZ&G4kl zLhpky9nA&Z;S`~UxUm2S%t2RY8Z+ti&P0k)2(`w5XQWu)>ONxD^?0 zEiJy~0R`%j#nhR0#uL9)xZJU3;!zxOtypv$#nfJ!(;`bt+8g`pAPP%qK zPS$TTYf(vObe`@}uB|%Wzvaa0oY47vOQe3iU(Y;=k$_Yub2IecqE#E2Gf5;9YR@2u z(WObH4kP@elnN`V?~3lEtB#f*nl*_NbS~U03iDXIc^Y zIP71C4h&e_K>J~a%>Z-$?^rG{?MHVwas}&z;RO_#UiRF9*IJi^c*D;31Xpwehd4dZ zFx`89P+*QT`2H`mnSwmZ`Im5>)k#Gc2Y(>;8%?W%=(o|DXb}!I8spsrJvZO~FZ3N7 z|0ndt^m@y>QguiETlSM~8z>);Ss5Y98ExGmKrT0#ge1pZwn9Hs+1SJOnt2Z2E+ZIc zw#)@>NQ@OtQvwsM`=u@_ydIF8L3!}q+?n`lb@CnSA+DWsGw$ZaK(UkMSC7ME+r)>V z_@8i@>b>Hx1i?k?eUPm-XN?+^PtkHO!6;EkHOW!NF`)DW?Tb48evynf=KDo^cu)?G ztoE}o<-Mh)gvbnC+QXxw9xTC)si>qfbE@ZsZ&d=dQ<)Z9W4g}v%-OMyt7C)lnp;a2 z&L@L(@p3-vHa&LHIU^dk78=`^WLDl$I&HCO%HoonHsL28fY21eaWfG71wK=fS7WC! ziihtlG^%BB0Y`hY^0%GGn!fQvT5n6C)6G`23Oo9u--gxrllF?LG>{A8!e|E;TP9r| zTNF}P#pO=BFUn_QEW@oxFzo~y8TT2pr6##KKp1?sgWs?3GjR?zR@m=acN^Yh{?(om z;uoQ{QWbjS6QRVF1%a`4a{~Vzug)PI>nLF957u?I-X3{H5su&MfN`a?S$A$L4NeTi zvbc&xv_sdYYej_@81nxS(bm&^3gp+1`3X2LB}HcidZ9k*oTdgGUk^XLRt|8_2z^@6 zFG5cUIU~QWOnW}U)zloHY94Oq_xtYB|JGxp_uJB^^!t^i=T)+5D;e%5kvKj~UVF79 zs_37WQ+&im;>Ji@eQ>I3g{9^GYH0ft7zv`qOiour24mMCtpk__KSok@RSY&2duYD? zXakr)BWM1=*$t}LB=kg6lo^iOT?R=WpOVm!aCp=pK@K(3nLN+%GavNm5O>p6-hW|` zQby(Y9t%g?06(t5a?VRD>$}X(5jO#>OL2bt=*hCqk$FVT&Kgj{l69%|TbJ zd+zKM{%m5Vb#~y_QI2L1kHL}MW2rHC&1-p_!>h|p>g2GTgATV(@aOFI)$C{JTN8eW zu{(ztRWf-6(Vs*lI_9yQQ>PF*aA0hrVO4A-7k$YKg9N>!_%z8=lBm#iT;%XqZX8#u zMco3#0GXgCh`wh&-7vp)c-o^y!bAdHUekugY`E{!rOy5q(&Q3(+iQ`?ZmTDZ?PT(- zO6wf+EY!n9h5{wV5BU_QAUvkD=X!XNPCwWljd46*4*aN~r5cCzH3mkZ=usVEb7!5oxRkBxWR zw5_?4@L+tr3$Cx}bez>p344E^j1IUt4CE=4%D`{lE(7{Yo$1H>)#Gg=7&%|Ge*fsaoa@k1gRW$}2C~A>YjzI&yq0=@@0^Gx+bK;gVJ9CXz;o9v#+=(R zlNK&jtjB_JMf+5cuyw};tv{N5+nl4;LLNw}QSxQmhUtO5#Z2epiboV#^&^q3cak^7 z@2^1zU9<5Fe=_rr-p3Kasge=9%P}@N5s~bSnAM+&oDoD)1#E3it--MCoB^8A#Cni> zaGDrBZ0nNmd=Y&zLNwrhd4fcxGO*zc^A?tPpFd|xDD96u`x~8m$mjK7yv{JEFxv+= z2=_|_$e>Y{{%}e~<>0HoW3XJl6rE+m-UQu4GV1Jl@6`1D;nOA(_Kx?3L}EsC;MUVs zAZi!IdpO$)YID4y47XgH_(PmV0SS9&++I}{j^KU!ZXbL@agI163yzVQCuQ6c4r(Ni zKWg|}OsdEMdGzJ!`l4FO{;CoK5c|c4y>iw@G3CD&E#)j-2n;aIq64UCvo2XV=fTnT z+qe5NI9XUCVl71x+f{HxygV(t7&L78fO@hF}x?zh-W>MsC{hvoFz?IR0+>1R@NhcEb;onjP z+LsbK%n8u8d*V*j7~NKk4UUxf0iJ1f13mm+QD+4yNL=Jq`H<_-u zA!|olb#%-_@3L5>sMyP)N)hpwuPO$I@?LLyXo7~z+!_{GB|0GeY-O~U|DTU$^rk&L z)y(W1I*a#i8`HwI6bHY@v@3#D|d%&v*^v=Mvp)SHm7fvv8M36`(I8CO}Dtr(5N zB0#T->YJ1ofoj2l%L_)tq(PcK3_ z>^Z}c1Rs+n#nG)Q8kxXy0 zeQ^3BqgnSOT2|nl$&}svn&N#0ZN6 z8pAhYsl|(TTG_`V4(I$`gI91N@ALksL6p?jfVH~6N7^V1f|)Utv|`xtyuX~e`$kxW zYU*X;>zl#zW>)#`6O+hDL$>#2=9xn~nVo-v-!n+RN**7OqLq*wXBpm}Ro2k^R$7WI z@YrFfWQf|v^y>V-e#8X}%l&OR7)$hOlkioBS|P?VdD`FA1`8&*(*n+zkp7;Be|;bn zAX?Yh#b}gT1w>v3sV?lTNf)`uO;e=}om^buM`yPN`lJhm+oF~gu0rG_s!~uLez2J{ zU)O-O{xwh9Gn6sw=vok!^18J#*_jGuYJQ9{W*Q{azDtwZo||J`E}+_}nad85=1)p1 zfqsxKV1K!Im9_op8arp9EgU!@_?s+IY?gbSvE(;fVBp_o{bfZ5uE*B7vjv>8rkqD6 z6=b$~+a4vxOEC4|D+K_8vN0~M-} z{&x?a$Gmb>Zy6$PTA7#8x{|UD_1w=3W@g`MK_S4b%4k<<=*~YfdA?jwRH01i6>X+I z@Qb5r8BIJHJEXhYY5724QC_A?-`gD{;9%k8E})LsV!zv5L&_9gCeYX0w(Q2W?I)Vn zY>oMR2bUUL?`BsT)6uFunn-8plsB_9wA$AYd*6z*z@CPh=SrqjF8bmd4bLi9wFc;3$eXzYC3oe&Tr7H zMwLvJJhB(S7SJeyWk{EC!+sRMxQM* zM|<%R#U#RMT{Dub(YdanD~Wx}Ro4D)i2U&##rCWhEyIa%+RoE96x~vW%coDyaDoN9dkJI|`;%0RC&`t+}&-{q}^m@&iB12Yr~Zs?j*o*?6chCNNn zsi8gQrz~vkn&-5Y!~{`mmzV|TEBi}!lxy6y^7+Qg1JNlK;5kJkPnL72{TYFql8lhA z%N0Ck?cm#;DFf0-q>(k1b-qpK#trQ_>yemHrjNXiIjsFeV>7WZIZ)Bap`@1>I$&tb zk+lAJ#g*|XmSsIRDq_C=3z(T%^SH+AdnaCGOgwx+0OlIGXn2#Z>C;!}@my0(Uvm{v#V z%DOZ}>k~IpRH4XTS-7;hGtDRUIJXwwY>Kr?Qw;Bem83+1R~A;suoPypO?#bjIlKXD zD(8fla@&G>mYCB)T07#7*vo$RdwHJ)?D7nq{sTtZY8Oeg5>?dDjl)X^c?!oa#U+;F zny*MLmU$UHMj57l)Gk3agUANdE$O2YGPBB5VUcvHIym1yw%(Gw#li{G5qvND&hRr~ zRvp+48$NUb@&^70%z{ft`r@|*Or)gx>`-(}OsXV`<78x9( z?(|Ob{qG`%-xP&04ui5|sz{i#Ji4HfU4&as#72wa`8*HFP%#{Jttriwm>3$L2EB z4i>G75O<@v#k$1BlQUSvU% zxC$w9VJ{}cP&Z^UP5lwV`EscKREmgZ_zNn)-ts)@n?Sl|+vON@bLT2IjAIEl4F?>y zZE;Y6VSp$1TPF&}weHj_;64qsVLJ)j_@WFfR?Zu8)!UVM{_ z9;a!r2LQ$sj?GwF@o!N2q^PDLuO1WeZ+1vyC^P=v5PqtzZb)qzqI8Cu&xS>pv&{UA zH4W>C19Q5K(>;MsLugcZFm}D(Zr@1ycQgw7dCv4NvNek~1_Rspgp%V#X-_ z@7Ndqn?JZDi>btWdzTE!K8lF@Q2KKu1xrPmQ4z$|1 zx-*9~#yLG@VqxX2C2zb>-_gSE!pXyd$%KY*8Dvt{0cM|Tl`H3OwC~d{Y!L(n_E!$2 z(pYdM^)GyGv*m+68ec?yjip8w$EBJkY$GMGN-w*e8AFjix6i*b%1juGHn3S|}HbLc^@kyKe#6y?SCL2lkc!CGPc%9Jd^JnhcojVVX z8MPPHup9Ac-J^l$skz~A*ykB3;%$1>>AY& zf9TdhA?e=_Sy8=3ldgq1H5T#yD%9DmAVW}0)CHx;e{97NfcF*4HIP`?#1R};51JqK zhR$w(0!M51yl-IgaC?&-=sBc9DK3rWLb|R<9|eql94opjna#$}UYp|ux_jHWKOe$` ztNEq$R=kpvD-Cr^Bcp%W2o}5K^vs`k!xyq^t<4(N)#UOE^SffSf#A43M`qs`q&@jj zn}9ABM4A!WCHcn;5LzFSw$G+!8irat6y~I1QA$UV#LBbED$JkMUqL4}^|FyeO+_67 z2%&m6EDH3XH*ES36E>=eFL}MsaP^z0?!*mVY}G3ve4#G0@_`U^71cFwwYgCe%yT=`val{>E9o>SIYh=7~1~K#`w^ z=}O`R)_j|xE1_o#;Ad9INcLJOuD_Kk~UzmDM0b`0_%-2fFiIP0YYq~J(7F%*U zw%q#e@L&uuJI5%DBjw|(>M~T^z$0H-!Kx4Jf< z%DT?NESRFR(!n1o!ob%Ug^1qlBDxU;iQ7wCq-GifIAPSya{`epa_G`DnSYy|wsb+G zI#3LwkKj1w&*EvBe~Tu){^%*;errIa@v^^~flZo1IrF&g;S=s`^t-QvNgc|80fLd% z+3%%Qyy(M9r^wSwv7_Dtr2E~D!vyDD?eULDx+X7i1T>gxoXq|KpKri00-gnT?CUJl zWe8)M%1}kb+X>a+4H)0NN9p;;sWfWQz3sSB2FrA{{$!1qlSgx`*uR%cx+3bd4AuVM zO|;)kg@kJc?Jp6tkm#>pX&5MuLoQcxh~C$NH>(?gG2mZ;34 z^<(BTv-W*Q@TDoB{$t19prb3UDO2U0a@4q^2L5+kN-lbS>&N>o^~*Vdd1>-nQ3*S; zcIS80zUT(WxRA$%) zFz$K993JB*oe5hV^nme%yU1M#0+q^{WSZq-*MFZi0mC=KPZ5T3h~5D#bXulQ<&H-km2slxEf>xzf!X_9 zxEXJzzThK;_~69TrR4<@AI()uT}wsp(^*iSQaw#<5swqf_h&A!b-3TtXZzh#b=wDq z|KmSubwI+*NHf&`b_B;@B)<3z?XX)- zUJYi(%*bFJw9Ta`Wr?so>jLBItY2FO)$#OIK9B`1>CJz~pvCUjuh;4bBeeb(AIb3b z13o)vp9fWfr~X%C6wt-H?jC-g^3p<&C*qDCYxVZxCFmaL)ud9WS?p^dN>pLR zPh)JY$#!ABmu(kVuk2e7mTEkB1$vEj4@krzNPxjzKN=nWJ?>)Pke`!SN{Ip=)Inb3 zj;knwIgJ0u(P^AES2Hl)o4`%4^J{$1QhD*88;wNpR5D(PuL}~bGs<_l`lfS7bDc9* z&@k(MRpOC&uaCVsB7WtzTpeh@Ar12@+vU~NCcf~WzqmCg>MQ4cn=!@C?6BnYv`A~DY zCJ&l2w)_5xhqo~fBTZPHK<`aIC1jJg*|hibUAxG_$Iv)6jMlzi;2|_??D1)&H#cud zAzBh*`Q_xHfzAaSMdgs!LK2?({fA3x>@;uXZRmeKt zvd|&GJf9KxiX_sB?8``;-;nnWNN^MEo5qZg>Ej!zp~#!wT`)&&#;B;vwvcwbgh*LD zjz1-+_M&h5v_Ou^#~jt4_LJ2$?PD)o=`iYii!ToFgijN^s!pd)e@j2-s={t6?$K6r z4Y?n?GJ2is#__AL1Lq4v`1~J!EI*c!Ty|)IN$~vZ;3cnKM-V|+Q0JCPfSEl&7mao; z!kz546qtYmvlVbt6AFy6DndmlAA zdzUgfF->E`aQ||cMVeBy1KXSl!Plr}Uni}IFz0vEZZqT|j?my?EmcRFqV2u7 z*%-@jvp?Z)-X3<86#HJcI&48Mr{zGpXsa7tS^{`LN>upEnLh6?`}AUPp06GO>#xA# z?{6!nrMbV(aIG7&2-UeOXhoSis6U>r06Tz+*ZL-}Ob_1QPQt$}*M= zCvG`0G~4tV$;C@O$L7e2`s&gIfX;X^Bd0VSjeeM|&gii#aX&$%VZ0a@Q#>Y)n%Flrp!!`tL}|M^C&rQe{ihRhnU;~wv^5uFj>QWcq-7;e{J+OYno1h| zB*W(HEFVK?QJ31x;L!FM-bZ7Iq#SnCVAE+8^(2h@SZz;DvVW3M{!g>{VD+{cH-UD_ z(ehhe7DBnc_INU9gvCDbKj#5*4+RG9Doc~@12*_vS^z|m3st^#Z?)adax-^rRkQJY z?C)%M{7QVjS&$b`I|0L(W?w6;XGv*oWHemW*2S$CKg*xAt3;ohc8PW@5t@az(0 z4ed-dkukfDt7~&TVRW6o7*tFs9ZJ7d^L~o{sX2lhs`Js~=}<+?C`{V;XxS)9tmT!@ z3&?isTy?n~HwD_;b1T_qKA!_}>hW753ll7OBE-KV{7jEB<-JoMXlQhCCr z%A}2Ux^imLZ~+F1S1UM5oaO3xuf>}+<#6fW=FxtRg^6Kz+5S}0SC)#F|0F}?^2&z- zH+-lYkD%$Yg&LeTH%lZ|Qv8jf7W{)lxZI7<#y4)!O#9j&FE3m=+~e5~+f|cavvuZ% zS1q@v&t5mK$NX~8lxIv;>o3}F_Cy_lV{U&Z5^hc7n$z;DpHr_?1_a1C4dkoT|CBwfpd(3E{PCqjni!Ka25uzPZx`*`Kci*HVjE!E1VtHOeWP^#elRff z^d+|5{Y8NS5|^eq*9t+0Dd+I^SuFFfBRmq?@bDQ%z<^Q{rHuD};hLOUkz{FiV;FJS zIm2H@$?H1oN$!|_FEl=w_ZI8RM>ZOW{FKn!U)Ed>;~~NP&WrTc_;~NfukF_^sOiz1 zeh?BE7JmV=Kmzy1} z{CP>l9ZH>Ig=rgbZJj*h99$KUj9ZSQ&fn|KfREG(&5e zBg5DR`mobzXN>QXUr)6sa~7V*-%y?|8QnXScBQw%;|M36Az1OF)Dldp2l*D?Tgqb% zc*fZ6GKG`2q}`5?;UlBPcS$Z-{nUB6%k_j=h_||S6#^#_70S|*(7tH4*2DP+7?P0$ z`HG1s-d=af4>vj3yi<<|6P4uS@oWg~FYR%-4&R;3B6#(XDwZ@RLROK5z?h*i`gb26 z4-D>}TR^f<#os!;<(rL%)lPB~ugQ6lHLK&T7$iotzZ)q%tG3-GZVOiyd!zbM*D&h-i)-(l90RH@|Pid-SejS3h#V(UTKf@5WR@Mn9TZ=fqx>S zQ}xe?zejQP%`i3;t@Xb%=QsbBXIWX;`eXd=xt4MH4HlznI(el&1T!=k+Qtp{$7nEtpTzXo^TnL*BwJ}3BMiIz zH|pni^71a!2RTqm-h<+{8I#EwtLE5*FR;ik9+BzdzG*_>zh8jna3vMx5lty>9tgoU zcG&xBEheuMtDVq_in8s(q$AI}_&a*&sZi@s@3Xr-9-Wqm8s`;pXBT*|T+0%aDq0c? z)-A8$hahQ)(sF3~R8=*@5&J0<9QXB_QwR#3Wd|@i_{)CfxHc;3>x_Xvu+?t-7P9a; zq+Nv349~RB7T_>t)u^E21zukfMW>{9Qf>9ejA+}cbtZb8ZZ=wDJt6dBBHsaX9LP@q zMbKOM=t(L+Bhd1XE?DT0f9O$1Ls>)l*X+@%oJMDg7@K(Iw1STdiU+T2NSlb>iVtc+ z5o>bnZG^n6=CW!|)I+t-Cob-Dk})Jdwy znRyJ(&wa=ujUi4fz|b%CsuCEx!~erwXGYusN)QqMF+t|%lM4mLopaa6fAqGC6JM9V z^c!&*$IC6%dC+BOjEQ3-j_JQ$+ntThrI+D?ry~wK{a1+xIw)S^kC)_4WQksP(Ylre zqbdr5vlTrOk>~H!uz%8#S7VxeD%F*aD*=B(cF7FQS9arI52<W-LwjI--ICRqrET->mPaLJTb3qwML7Q&PA|kubWVUH#UpT zYZW;IgVuR4einLPK8+%NQZC2Gd!6k)LK6n68a*L*g%}#*M%rJWQ5!$lt=*}i0WeBz zb@=bO%^o>LI&+7E93^v~WUE5dX2}zZCx2dg3fY|E+Z2tbqL)RZ{V>wUhrb?$@ZmzA zbsVCWPdh${z+Lq{x9F=TW#2hpjN@ic*H>rfJaX)Bzvg`|_|;5w{)=X2&n1}Y+p4v$mIz=@Um>t6E9)b-&U}KZv3fA^ zk+kQxHC%IWDJ!>kPfJ_6`6828~+SO$SkXSOb>( z#0D-Z0vJC{A7aK=zHwuJf{1?^il`mJH|)WRsv|*&yz5m>-aIGBS@~7XEhVG=gZ{%W zGtT`uIQaRW`q*8JT^X;&SjzIgImTpkx*M<)w1zssKK>|Ul$ieShLY57tz|Y$Cb5pL6kr3pM~su^VX1Bf^X*E zhEdD5RZnjD)Q#iCv_IZ?tOI1ZF-^o~qLxp&+E?7jrJ|<&U4N-<9}#v>L$pLgA$K=S zWo;JMYqRX>`JMm$_43llLZMUK>$LbD{!G`h>{^AUw(O=6+$xffilR2nzJ_O|5_|Ll z=#ej3WO1p*ZNChm;vn)k7A5pn2>eJNDv#$UZMMS;wtb&oYWn&*p7yhj@-M&pKN9hD zgvyXNXyP7b_Y_yz9b#5hF^!m6>tpi6n$xO^B9#gD9MQ!rDGT97{XMK#IZeXP7T5{2 zpACZa+Z0@i?vth*B{ZKxZ_z8BuS&naCj_u4$=YzW}yD9>M8b}4XjoLBW{$$IWgbP4VIXd&ak~Hco3qat^Oe!3`ho} zbCgx|z7><0xvI`C^lAq zecG#kSdiQ1U$IYMuCh3jOmMo|U?U|7HCy?S-Ew*Se9WM`N7hFfmT7$qTieCt;v6^M zzY_-#hkO_n>`s|*N5ZRMcd!;t69d(z-pzz~i?|+Gtpm0}v*G(oMMEahd}A=q8mwY) zD}#CNn-vKu)oB+N2?kU{3|XGDm>4L9X1lw7zv=XNd}AB& zl?oO8<`ZC&S&JJIU1ax4=oqMoqZ0P(H^3Wwt5|eQOL&Y(4e*#}b}r!NI9&5B+FL-1RWO<;~zzELL=yN6xZ`buRRLX!>AmD1l(;N@`IEEP9}C zuET!zBM6{Cvq*&O?b|ZO8C|b)DY&8=aLe9^vaSd^W*0T9xuZHZ!$P*op;P|lw27}< zFF8u~ejdo1(G>n*%u(yVWPoib9J6l!Njy1ZUg9s22Rwj~_nu2jsQ0gb(v%xa?XQG` zq-46u(DQTSA_RaAVI!r}J`Kupe;JmdFuvsg#CS7G4g0&%%WD3r711bsf z!*Z4mUx(L$yM8f!m&grDJRu%0%|+j_{QAFq3h$5Q{^owe52||BHJH#-ead7Fvd-)WBU>BGhD;%d>FjzP zf9F1xBU_r9M5>s(haCg5C+6P}daOe66oJF&2G{ck9Z6o#N<~Fpa{uo?_%F{y&K!x$ zt;>zo7sxk>BF@I*&No}3<$9{cq0_}lxAgE%Uu3nGIhy5fJrvFcfH+HrK~CDiaw^*M zRjAgl^I)`)p#~9dNgYqf&p~KXGmvQZLqRP^&d`*%u*OBZ(*gYWA-{|)%z%rabD64k z!=8_6fa$2&5NXB~$sLvHTJ%`SpTL#G{Q4AARV`~rI@I_2_pq0Oab3DrdqFnn_@|e- zFvq;_)nvq1?U{!KWQ+O^RYk?f}-T} zkL+;FY8m+d^4ki^?e{}yN@dZ0oNOw(zkWY(|$vA}GEG~^pSy&5&ryn{V80B4Yx zbH@YHi#R1tG4ZXjnaf1THoyt6U(KZPbf!apoHPJM1st|Q9Ukav)0#Ea(F@^u2gg~M zd?>FOnR{Lz#nEoGScUL0R9-k)e?l4Yh&4tEYX;DtLzIuvY!2c} z6`FX9Jg%iAz7A5%jw}RDF~3a^VhMg{ONMIc~pT#M7ZVgzrrGd+8SUS zTi5@&V8-*W%eNMv@7OTG@F#Duyx>4>V3oe7#Sg26`S2uf&bTR@4rC3SM0cWzXQaa0 z1&T=zr;pr|t-d z_~<3+;W|L!N@6pM;kt1rhhVEY*#$sEELGg>ieNXpU86w~lm1S|u9du+$HobOtwTQC z6(+t=ORfsl5ZOmTJvipOa-PwyA~QG3W5${vXt=7hg&jI7{D1nng}mN#I=3t<>vV0}Tq<95~N@<7QIkrryJFKR07%f&6!b(9ag-HN#1eM_3Lw$A@H zrFpA`1_azbzB4|Ub2IzAS7x1qkXKX#zWBpBqJFa*0tP23YJcGv%2%T{-!*i6r>wOO1eRi?nX*LO6g9K?w0O`JLmn@|E_ykF5h<@4`=WFJTvo) znY|Bb&#t<}pNz8W#hPh;rwwzuxBJ$s+ym_JEsuJm5eM(pUy_tdkQo$m1d(ZZdt9HC zIPAaB6)_-!I}pFBIyyfQ)t5;2+~7tuH&SK$ZDThoSb-Z{Ry0}ja~8x;>^E)R_-8!W zwk$})$zgdxYAg&32VIsJw3Z~MVxot30RT9P2t+jvjWO*?4H`*r>`S+8XoBx$x zTvvU9C7W4oHmwh>1*#5d^+E2B`QESUor7l%lVqSbbo^$uG@*B@#u@)exfN0v>pbIV z0jZ?U_A}R$q|p=PGD&rL;)Av=!FI(p2Zgl;evb^6D}}8~LcY6;OGe4v$;$d+V)7ge zI8lso!d`+pUvqQ>KHWp~28NtR#`}*S-*@`g>RxA`j`h}M?Vq1*$ZWkewz6_fgR_8g z)?E;uhK)BzS4I4_i_G2)1Isl9W}EyIq1}4ReH(Es1!XAx)s-cPue;Xf6So)@S7~Tp zpiQ2&$Mm1>YI|y1-uD>T(3b1FuUg7Iku}OFQgv~i&w_V1RMDQJ+y>cWpw6pI_Q{c< ztMSSbcM>hRJvsErQ!m@;X8x-o4*}P9L~DAO-xkw66KM-kVcvxq>Mra(>qV+9SK~-LOzM9obE_k-VC7kHKh6Ke@g`TW|RX61n z)AnM+GgXCm+F$EQ3aq$)K|^7M{zTiq*Ya(OKNVRfxvDh2Ggog{PAegQCw+9pwf!O zHB(8X&OTefW!fgk3n;grPJd8FG}jRi47CUjkMx%;BXDqI*t9=pus zd@nbFIOcY0DogVK8#ChPv;#fApy1}x@asJi?gYZtmlwQ&D^|bES@cRZ#Hlb_76Kkv zp5|vqM1Aw72xu{2!Q5NQym-IVofH^v2^DzGd88l!NfC9qNFTN=TuOV^wnP`52JeoC zed)7r>P^oA2Uj0sK50j1w~wUGPB+hF|2|&$C$>6|wwIbf*4m8bS2&Y6UBn-zP(sz! z%tBK27tv2w6cSg-VieX0_pw4KKTh;;6RIo>BJVkoy|eY7_M=QHc=gTG-R;}ZR#z~a z^2kT-_i9BdOG=U@H&3TmK8~@d_fRZYFO(V{p9JPdC_U$Dml4q;PIkKw4@Q0C#hl@V zB1y-QC!LUuA-(P>P)(=;CU)t&-b2*;c!&NE#PrpWvtQ12=KGPMRBv)#rP~K3UvMQ$ zVLbGsCgd$Q!hY)C?bU1iJVG2Jqz#>esuJr^vo5k#3T|^e?-B~!PlQKS+^c)ed(REk z#am7I@$l1!`I&|KKh*PHbeecbvwSA~SZv}`oQ70tN?6V<-!LV#CESoR#?AkC)q&0wPYA5#~sKZdHK)N!%tnA za&v65>E?X-nd$z>YMYZn;KN9d`{JR)fg=>p=?V6nK9}lJe7oIkT&eFkR?gFTq@ktq zBTK<$RwrIeLc0GkJUA$Q&|0{M>>}2N zx9gAMo357CdB3sefKcYya_bpvk(Nbw^vi7^wtNTrpkeFacTQmN?Y2|x{89}KyHjlekO}?4w z=P&f_ULuj`98Yi5xTn!v4hg8RM4V;dYmdNE3`QoidwhTQKKYk~gnDi*q(V-5oNqBo ztP=e0?LDWQD=RhO6ai`sGByM?)*S58s@&X(4(x{&4(Sx~Dqr8oCws$lZ$B{!Gh?^< z#41yn%*a0Mu;t069n9eJyRBDeApd!u#wgisfIWSlri@|3(IRsaJB<2EOVrqw%D?W1 zuhBt{;6L^Y5YboE=BJY&tErJul7@%WVFrm(XFaU3A*4a!PI54gSy5Ege$5}BZo&L| zG(6=ag)n-XCRDSdH(3RV~dq2V`M$i+_ zlUb5AK0TKdS6Tjxs6I`8LP$dN;Igi>-V>ZC;nmcXQgK2o0yy31XekQJ?ogI`>A8;+ zxwM__@x%qqUd#c8fuH%Q(~;72*)J0JmNy;fZRvx)sxCgqlOFiXA=Lj7i84MBTPrP} zK*ZOXh5!Z^7Y4_o1M~Mx6GZv5=05V!bocAYC}TjQ2AaaW{+FliWa0_DdUeKaK81qX z#Zj%5#9cuLctk?4K>t49{&F9HX~?Bh8CrcRwvNzH~Aze34d1gDdeVqLc>r zLUjGvV>0R}wYDst?jA?~=Dx!H+latUpcEQJPP*g4_{1ro6;r(?+w9z9*1R%gazpFd zc{v%L@X?e;6DiR$Z%2xv?_lLbiT~7%0=u*C&5LP9ihwSCu8>ONO%=Addsm5DM3k7J z`WV!+Twk4K8RL@1v(CVJ*u9&fe#S_%zf@DO?)o;TSri^ju1x7Hj79s?za8(`|4V!1 zSkHn_LXOf?K~{N%uW!v7g5XPEQkjCS7=((4{iZD`aYC?yTB+;k)kn~!r~;=MV|!`a zjaddc(C362adk-9&0h2uia6KEbUn0z?-2&`404qwU&tdA3pqrNInW`d16g`_R4Axb zh%A}ZMPyo7p0s`Wb-0u;>%LQ)mDKED^kR!4C)V4nFJk)ZYx(j2(o#rXaMB+w-|b;! zYDqFM>A=mdpM49Ye^4CNe>^%=ZQ#4c3$+iN6@^Cj|Ex33=mtMh)N+w9-C0q^9fti+ zH`iP{8e&;%CPC4QXnDZO&BmoE<1Rb*Js$5^3?KgyR%i?D$6DrZ!V81k0m9}&?~ZtxaAh{Oj7l~0A;-9qiNintgLp|_gpwd6-6-7~`Ii51D2>)k zm=u?zAbE_qalf;KhzJ&CSD=kLYi)27KXrp{bQJD=ynX-SeXY7|SN=kkN1U}GOW3Zz zyC35KaGm2b$9`q^M3&btyui-L(GgjYkD-vY!=}%Z&Q^qk# z0+ja0CKXo{M~|Xopc|hiUSF*TE!lH9Gro(ez@N2pcoJ~?bGBy~b93#;02AE^@@sSq z@wk-ismg|T)l*?SGJLajo-g5vSE?t8dhLV94@_70S=?U;IkgKMU66;Ir=Q|#}H z=rx^_($k?xsTs6YNxVMMSh5+!A06BKkeqL2V)@J^wEnObe`k>IX}pWgV;i1#qC}%; zD5R#Q|9bj|*dL3$6-y0%hQ(GaHqIjKP~R<#N4FjnU7~3~eR$_QFfoRnts{DqH^TW< zILGntrBx|9u7-XS)8Sh868yGqGeOdV>IAc*0i(2h`h#+tH?{v}Ox=h~BK5HT#YNmz z?eVsF!8NSGpJ$o5k&^6E6h%_t?9_EybgezckU0G;qAWBj(~8J$|E~Q3l=f`fZ!u z6B@(SC3gGnu0@N)^2J6)^z7sewwPM}V>|LmyG6IA8tZ$=djA?ax+*Jz<%j)##Gmum zTD^UxBi{XxqrNf4l0PKQpDk?I!YYcpb;^mWzc;ak_E7!Pqg?$RXwKS8r=y)q`fFl$LORafY{YLH8(2aMp6u-V z-*nmABqyBiQyV9GdxD*n>&ZuYJs~F=erW6JBvd3P=k>VPY(LiPQ;jZb1=dzGZ`jW- z>-Y5PbJpR^?je3i?g#U|MxEheJyTRHgYR|Ad?_g5dQ^Y&ZEtxIP8Tm8q;c9zviR#`n|YS4nh zT(4#x-QZ6iYM*{T4ifdeamKrFx%1yoGt&NpH`Rn~=PY2>ie5dxJ{h)Juu;Q5zK_wK z!*z;;UiNCns3yz=mjKT#vu+e|u>bdm`Dxo;(|-nWMjP`dH+}R$C43J2i?Qrqx8*)v zd!c|Imaw9)_rC-E-{0OhQz)Omc7>x@kNY1U?bn6#1IX=FYqtnzkUAy{H6I+FgzFq9dT6e zFl@L=2_bZJFE~zQVT}hpw_kR zZR7Xvd)9RBc(a3!c;dgly-kJv(EZN{%cmPV);Vv}u>Aa}hT-rO5(dSin%~sl_H5gn zs^C;plxDVgjSJm(h%$Emce$Z>7^Er8?9$}%QkWs4*1SnydySD7)At^jCiu|=PI+8x zDJd&kZ%^?YZ%+qJTc(t1vc_lTy>TcFNrGNP!uthL}TK603T$Pn2C@@>}S+GTZ8<*=< zO-4|xyOX>tV=%c8ilo?d;8Q7hDu1c_zBK+inqvJ~ZJklLeA7##wY(vn!7qvIdZ%eC=Iv9GQhL zsG@oJjLXunFOlc)_;`Qio0UH2&1x7da)=V02waA;a*y%BeO~R7r|Pi*!en34zYMQm zc5EE1M&eKA=WE+<*L~YLJahY!E-oYU#JHU2LV!F1h4b6vZ@tIGe|@CR)EGor<9mvi zae|Dx98+l?<|&y^GYvg-bi}IU1uBzmMlvl_g{}=f!0}ZaaLldVS+O0HY;0|< zxe0fI_;BhcTg;il>Pil$hkOgxcG|Krn``h#Z{#zK|zMCj_c;wiYH~m%^0Lf$Ofw^A$$HKcSmPu z{6IqwFLsEi1XUE$j^wq`$`?|U;cfC_3O3@IEgt4_l3T@R>M+?OMKwE89@$CZ^%_Y; zzRaCk0<)&3X4Fob<$uJ76a&nX$K{vQKp4NPF-$qHCb4sImD{e$XuEE4d~&QzgNR4I zeQSSzzt|kRgS6X5r8MH(w~Nx5Up~!xe~Y!cPa1FBCEvbvUI?H*oou{%2cMVoA|s!! zzP|qbY(tX$#^1gDZl;41&zlqE+g=a8sVOV3Z)_Nx?JrkPTarbfKrI*#nTesmd{C#& z^zO&|w2BJsoMA|o`Jt@06_u4nXxphL)@Ek3ckkY1BQ``9B#n_A932%Si#RQx@lB_J z42Wgal2lfHDw7k{&>+C5UDq)%Fi@+4F5{nVhU_g!M@RRS&E$!Zkb9V$jrOz!WAb<6WL|J*yr9M%>Js~0@f;L7N z-`7A<5o>vQ`KuI_i>qtR$NAg8b^-M`DIO6KQAI_C zZmC-KR--9N3I63`DQI?jL+%miRPrQ}rT01&W-o`Oz!SVpFXU1Fllcxwb5ot1xMlVA}BM89^ z5bo&j_ndu79+*Dx=TFOz50hU+GcFf!8&ew_Z?FoKB(oI${QbNC>kitS9k27*avW^7 z<3~u1Gd^WyBRjk3oD`3QkK8?vnCgBNYt=%K>23Ti()7GKLM(wvihHaRzejK2r@=w_ zgoK2$gO>y$L*wI}twFc}?Tc7#KIt^rI5-MQO6(b*M>dWY4$i3Ds4xTOY-+x8SrNnX zgL8E6{{8X|BMFJlOkNZ0$^zAr?83s}sVU9$L7F#j-XwBce^#R7F#E9Q#J{w>EUnHZ zh#qbfpO%(pp~1wuv&VF^&$Pa?v+CG*LA15CMe;L^9ji!1L7{1?#+H_e337j6c(}PU zin=_IaG*V)dpW(L;xK<>GkzV$H&KopKbW#_xL@p-TwXix9h~;vS@vNer{@E(aqbU$5_^_w?$e0_Z(aQ*!J`&xkp*^NkUHpp5%KtVtO5xFu zS01qNIfb)hV|Y*ey(=w773AbTwdy}7d@?*U(>Gpjl4BRzfyWcaq$6WwL6&d~hd>0!Vn-t|XIXQVEkL|sZQNCr3 zBASN)hS}Lv=H}*5(t|@oZB}~l{rvrN+^a}T4An*Du~&NHn>WYH%RlgXUh`{dX;n{e zqbo9p~{BxW9s>jPEEYA;uLb& zzDC`#M*SRcyz6U^Cq%sL%VXVm?YInTh$XbF;;Z!+aNef?3=)k`$uocg@x=d*h&*g% ztuZ;Ujl;vk-K8j5eb)G~8e3}5^Wl_^-rl6$n&8CumX^jdbH_uo z@aV|saXecD1P2XGD&#LLB@q!MoSkS*G0j6B9v&rCRq}3D2xc#m&yZ*S-5XuHr7P(XZXq@G#_ zthaI8_W%B5N^i$dL`nQDRB!%izX5A}eC*)j;zC8^4;bR@?d^88p|je@x6`uafCwEd zS^nR?k8r6zv>24)L1E;-K3#l2Qy0ImpwEEwsM9dAxR~*W)ofqm%|)Y0ef;^wh4bZJ zyLzL?Rj_u}XHsTPPMG7pG&tUwy-9xRi*L~P>gM28xyz9`tQ-WcQ@ysvyLZmV<0go5 zMN3ELB0Kvk8FqWS{zRMDWi736h%8=;m6`chUlQNhV3>zbw%JZlnUp*>CDz zhb`LFl-Y{J8X-=3zmr-SDl}B zw2ssExWFAQuTL8LMHPF6sS*|}bS?<0W<6XT8mOZr01_~l@GTbo*`{@s+b9+NT)pH9Q3)q_DG*?>H!a0po_ z2r3Tki`J5ol1``d{^Qlwg$9~$=qY0dn z0uI1(Mo4Crbrn!2JADSc2G+W9WHb4p%KGN&1cJdh_z!I5=0s&s%GIU`R8-Cn6D{6o z0^)jl^U=!`5|CtAmy55i$p=l@H|(bYhcdzhy6R76x}k0Zls+d88y+0|eY`b^2q$Vh zL3Zip)q}mYj4)Uq6>|#<>)L7HKu{@pT#p$V&w8!6?bZg6yCf`T=do5Uq3#zNDdvNqNO95EWZm*R`)HlnY0qN;vGKt*3 z5&1ngHlRNNv9;tf9aL`O_ zIT)%N8yfn!UmWkr-WbyTa&U0Ks1NW1DC3tKjo0)gb^Bc@lvTyTu||n9)H-Kf zs&IMP<>gWRp4SJ_(b0ypitGegpe}bjPHIQ7Zs6-?%&?`&&K;bMBlxetT5q#7l0d{HsLnB78JZy>$tO0R5kN+<+|^% z7s{nCg+y?0@a{p%ja$YZPP8_PCDi3{mIPhi) zU`t6!iJ6tPNfzL@Ao&hkTMl9|3N>pm8)y5OJkQ>0J1@tW_1{X!W&vW()v)*2vTjcl zio4ThX?t0OxFV=)wF*;=uvgcBw}qVqO|1I|F{;}0zr5~@qlo_JGkAx6md$QkYNUyO z3=*=*yb=(gbfC>^+kls>*_Uxm-ebK;*R!FYrNr0OXzVb+{cRevi^NtUp{)NaL zSXfCYWSmQrTMkiDRMdp*M(U!G6b*YJR^LBNa!y%tfu-@#mxh)$Ek8dnA%WmM=d;ojFSK;wPq{I-+^(-hu6n*rehn2myj11XL#`CI%G^&Gz4) zJPU>odH3P!rnJWrboS=Q>qWcs9kK&Z)&DFRIl&!Ar?s(8x#9{uhY26@800sg5^YiJ5T{l@ct!9FIjmaa#fg1sc zcZvQ1iPUg&wL>m}nNEWl)QX3XKVP78MThq4l37|>8vYfE0-}a_G~Oi4t&j26tyzmq8sL&x~;a^|K$jBbjNuzstKTM;4ai}e{w(fh#8E@EiA@;XiR%qnl zqWI*3Z>PKRuewRn8{-w{NP%iOAYoir`Q`LR)Dj7W#OE7hccTEQ zD$s=&6F0QxVPad`H09$A$NeS|nEka4uzFO|2`yhR$dmIcjt&ko%;4)Q?%!uK|3d@o zEs`UmriS-nqQZIpIw3Ze%8kvAmuP!?8@Sy!Kp6uQ6YA@}=4Nld$IphUtqtCP_@G|> z@wtG20A7eFG&-TunGB=~rk8PYap7`taq+)=`QXVD(IOSLk8`i<>g!qJr8<^vV4+N> ze;(Pc8aq4lYb93BA)MYbu{|9NBr$ugZ-sWsc)Q*t*mx<`@8!4WtGO5 z)XRhm#ekOxi;0zgoNvlOVYITcBH*%=5*9`Vx0dHDObu*%2w3tCh>_u&D+s=9Qq2P}S?nh4{3iQ$F85!Gt|Gve<#Ps>|XF0lG)DX6KV z;^S$Teo{bj94J(0%a)24Do`o<4smiYG&+h4Ybzuqgro+Xtf8R+wiETAPo6MnR?BIoXqQj}p#=Yu)#K8MjGSCPPYDp{GI7x=6ESMfwoYhx9{IGK=i%tO_mVy2IXm2k<6=e;-dJ~cU{FIC5Z@GyqEKZHX zRzf)3mUG351UxRk0^kp3h@ll07FOG>V^&$s;;>q% z7HbNYjNXjmh}JjJ*CGDgaP0;noiX`hf0aB>Q!onEC=$(65_l>+avVEg%I@UqnlWhV zba~1mBO{YeSHy`==KSW5rFRTG*)^SPuu1#fq}q@v2OjDI^xE@jXlRgL7pXL|>oF;x zO6>0KoppCM4-F~sI&6xU8~1|MM9sny0p+;f*Nh`~K4E>TrX*>q=kJTM#z* zobU0#c~}KWXIiNtaD8gfaC3cOLbU?UVn9$(1{^XvHWnO=-J#6hgYbw5K}}6UkOLxF z3f2cJ(g=IuXk%1FL;=jJfRGTQ%hSCZF@ZZ^A4+{?_R5z2rl6(Oo3C8(B`hpQK0~x8 zktYff{yWBvR0FYM;QR@LlKrG+Z?L*L!!r+m};a<-(wfq|F@*Ks-?znhxUL!Wp+p@I-z zaJ1(oa&T~H^1Njp`Kc1Ea%5ygNLBmXk=9HTNQAX@>1mZo{Y|8(t-L%24h~NB&z}uS z@5T&Sp)~pUgp`-FJ@j43bJmiRdku@t!omV?u(`f;2KK=B@Zm!|4s-13pY|CB?V(fk zu5sS?(E^2y;K)E5OxX!4G+scR;^yWKkBg(mw7|69E8K{#qXl6X-f0NQennZoM7>(^ zTQ@i3*iwZ}0``!hQjjuq%J||%mQJGwDiJTX zr>E!kbZz_oayLS6fx!el9?YjtMcPA&&VcXY@!I8zU*?XOGrV|l2FrZ9o|E#=t)wC@ zJ1gq}Yg=DmDCmDIct4yhkXpl$eC4H;m7xMvCV*_8oE&N};V8(-y+J}4eE87&>Us~Y zbd-?KAs9e1B03uJ4ib;k=YxZTb~`gk5>Zs8z$##&V2|LSz@BtG&M2v6lMC6}BL6H* z*E;SU9=iSGQC3w|g#Cg`r;GW=6hSmJMbE?}Rx+BQkarzSq76Lbywkky$zHm?kQ+Zcq+0xPy9v$5ReCrduI9X-23VN`gq9s_ATu@Mu zLF=bd$0y01SlD;TtVk$Cq9FuF*cfD;!LXEYJt2T6{QQX+zk>`d^jXEK4#lFyQe!F$0F4UR*t&O`ej?#@`oklC3qit1b28V`Kb^enmeV8Zy0amR#xWO^USZ2|`5QsHosrj20UK zJqDm;dq2hu63cX^jt5Xdv%yW>$w4ZCJpj~MijS{vz?r(Mrza+uT8JGQ8$&?6goK11 z6bQ|qb^$U;ynFNhB$#kzP{nxM&)K2CK+F!78NG#1g~$d^)KB)+RWuX|41D~~Oo^zq zKY3WNm8+L$2cXrnWa3Qh?K$XT0%04%VSnV+)H=?OHpWX0#FBW(S_2H)f&X=M_h- z!x_Q4ycVh60O9Dix3~_!(E|OV?xm@^G>JfbhTHdgIm!$Gn0%kEtcbSA`=s(U;$3&ywOk|Jwi|`h-7~VbF!#T)4zrB zYA3|}PQSp*tA-{J2(Ut}+E~l!pBO|$y<3x2@hk?SNe2m?hGf7y0oM`e201PW_f^3a^1&mBhO-&ax!|T_d!u@}`o-jgT1wA4F z4Sew3n@Kz`pVy%!`zSqrWz>c3=I&l&v-}X8T`Xl6re|TmZh#fs`IW@y$PGaKc5>6! z);9i?rvMOELnEU-@N@zLJ75peGcw{iKM*`7U{3||_Uh&`${&pcDD$bwy=3EZd?0FR zB;R6UM5-*Oz{sAjEZ#mo$b>{h&X8i(j*ce4n(a2nBEUno zUhay)z1Q9E9WAv!kahR6_1uCU<{3ZqPm~AiSm;+`kpQt9po~^#Gc4huhQ7Dm6l6au!c! zJx!pcR9#aO5gFOE)EQ-MYiqDO_gkp>6|_V@0cvK*(TGb*Vg}&VX!b$RFtE0k4}c6l zKOBe`=no_mR8-_Sc;n&DEQzbD>*m%LkJBzGtN{of5|C$~y}c2(Hsp;P)RQ`J|6m^> zeOh531;dD6`t+Ua5QUYMu})N3g_V>raaqkQ0YIhylxIbHjEYJDIs`-zbcjmJ%9@6U zqri(ri3H^q-P0ol6uIU1%*cpx-7DTCnT|dk|opxrCDDxF@-T|{RDY@2V^q{7J~av2+;$i)|2V!=uD@oS%`^= zGqbZ<&4wtzA0kIdg{p;)fg$MRRE>B4(W6Icwh&mMnvA#`uc2(V5WEBJVQg$n^w`jB z#@W>si-;(svGEmbJVGuE=g2P3&+o<269GjtuY{+XJ$n@I9~_LzLKPXT10oi*RMYr) zJX{*5^&Y8KV_Y0=e0;oSjg4xAcT1gaW`1!YzsyKvKLcC*vZc9 zDujekp3>hMkGvo`swnVkMZa-=KwuvbO9~1K1ZV*E{P^($p`h-g5$`UBy;_)`7X*zv zR%wB&qM`y8N}57mM|Zcu>FzvKcVKu;fb3six^exQ!`(9a@1wG`A z(UPEtrxV|Q{P5^nQ_F8DQ7s{cm~?V-+FctYhZYp$-wP8#y7%G91(I2092_yt{%Fu2 zJ?bI5yIutb_}WAzJwb>cG;m-?-$3&+>7@SW&xF)eQjJP;U!d-wI*h;nOk0N1>>nngA%g>32!BDc`MwVZ(Ex=G)DbJ7Hq>jxqJ!B13LT0RR4fo$ zL6!(_pOBDnZLBmS{1+Vq1Ga0zA)|!w3?PLU)FVntN(8@wvhv4S23$6MCRRLDLSDNs zB0)Iw|Ncn=dh9gb2!L7%CgO_)zOuNu*bJM4&;V@4J?IA3m#5ayzFVJdys^bH2Udir zWG+YlKuVl~13&iT!!uZEtH!I?xjEgz#|Xmo!*V)MyJQvOpslUV@6R74ME?==AENS% z7Ha_mLbxisyQHkFkpQQoG4Aivc!qHo$9z)oe1*@qqF`3LOGGJaLeGU|i%t zqZW9oH7H4w{&$FVv_E1mH;>H2uow$Ou3sR0f^bScE^-C%>Ky;qQZ8tdtm&kZ?YA zyQHp;A3Q_RAec7I7? z@VV+UkXm3jq31f^U}0g^07!%EW4HK=ii?X|tkc-n6VHagMc{%9n!xlS1_=>q*LZWy z2RRRvdwy~84a^@Pq|Y}lE1X4Un8E-s6{%W^>gQx*ivpQ6SY@RHl8nb@2?Z9ry0*4) zx4x>P0zr#^{dx`7O7b%-z>>to#E%XRZL_l^fD?wS_+cck5RM$ssxP+%Vc!5ZVRP-) zhwcKV(b3Z*^!KZ)LoB!+2)28V34~2ebH_F^4=U~Im6erG&(2uD76mYnkdaw|ASfaK zKy{@BW<@-^872_E9Hy`!VnzP{ood6CJIVzMVzVR;StsKE3lJNDYryR*thk^;K&Gi68-tYz?6AzR z6B7xngf^cCSU$uJP}cyd5OR-VH;w=wG9DfR5HlPBJYDy!%+U-G(S6rsoAI1o&ljUGwB zsGgU|DJivpzd!jU=n)zIoJN~zfN&5z!B7$cmTF-!AK@_M zvYli;l+BgE6n|H|ra*wNsDuQB zl7_O5fRx}Eb{vRMByL__u$z6oz3%|t15Zg^c}-5a2bv3Pbc7;-&BcLu210uM_bUPF zJ*blb!NDz1FhKcPLpTbi!z2~PvuEA#Zt(Qse9{&2=fbd>i|v_y;TwLSrM$3_U3eKss95=imVX56%$} z2et+YJyvG)==t;KO{w+O)nAvE3(bx0dD1s2!R+xwWFo*v=G!gaL>mgeSuff*b>Acq*xsfd_MVllY) z#t&66V^G)1iYXtXM7v({=KAUe^MQi8dO$`7IlPmlH-V$skQJfPLCXb$6IwC~$7+XI zcG{{O750a_NAB*_pAgTwpAbRo5W^w`&=Yt!^e+9OIuGT@lAIr|dxJd!PBhpaY{2$J z!-%CF95^Mj_Vc#u;{-LZD8XYpn2ou?5t;H z6$bDI!vKrmraP_gcmneR{r2DnT5qg~bW9cWhQ$G=)oK&;<-UJQ`k-WQtu7&Vbg?-Wkq%iwmdK9!9bXH4fMmjJx=9x@l_@A3Iy3 z{@*OXC#S|3$*i%P6p9U4GsZ9ZKv@^p)@T|KhBCsV16D6sd>a`VS@**SX>f$Vxxkjt z^%W9?fB>GKYSNz!*!o~dNaNWpzb8Z*)PEy8(gfVk(TRy=Kz^=WCkeRo7HL#A!5nf^ z1W1bpsKdONg=y)4MGTdIJ9Kt9ESqv*eDh2};RxDRggiFLMJn`d3vdP4xVVe! z>$DvzxwL^iZl}ypP@pOc)$M2$#X`Qp&>>j$&iQwM)FuMDumUVVT2{jhVv6T=&A{&x zqBrqyA5iOk6z7tOl{pwbTS$O*i8ILb-JKI|@O5B)fp32XMFBD=k#OLmuAc$9J)QpXf2Yk&!_;SdY&Mz?75yA~j8pwfw zS%6LmP)soQnwpz=JT9xhyacU;Xd45LHiKZZ%SZI2_7+4S^W{vKqA{1iXx6uaz)W{C zgI)=AfuuLRhD^XX7XB>-zZB8E{JQtt@nl;|;9`>vl!&ir2F6_;pwD;bpaqI3>(H*k zFag0E0m}$73n9^HXqX8?V6b*~ubm{w5hh}B>a}6o=x*cH@iU+$ZvWns+jmJw9D}oO z4I+p4MDC(ST|+}4>eCBGMg(aCK8x`QV8;edRq!W2{e$As4SBj4RAPR-TxT)ZLJIu^ zkPFcIdJcUZQ2;MkNHq-&$V(pZnjx6=A4j*nPLtDRjjy&|yf<2+gL~n}u0tZ=Ca$J7 zrl^e%m}Iv#q1@63T38?0EwDV02S+_@66WS~hzSFzOi->-0;*wzA`V)0{K~qNtvpus zM>sY$gdpRksNNKelo|IDly0l5tM5T%fxp-aB0*#{L#bkAbCW^8DKZk%a;8oKI)|+q zs^1K|`udDv)(dPqUCwW3S0}SIo;T1f7O>y$G-TC-Dh;{`sx+`*R=p-D-C^HpI#TwWqe-z0?Y>Fa~O0 zU{eRoc7_ta?COf4hbdmqs#{Gmb%AZ%gR+OwgC8Oa=L7N+ln}5%ZEbBI zKR?zF6YLnw5;xf+=Dwb7Fij0+0_N-oc7nbAeM~{}KNL1VFIrzV(#%C=#%gM6ip$D+ zc8%?;cK`{1rjW3(FbbRasiSV55?yCck0I!4&3YH*%Zs4TT%F!}TN9O+FtZ7qCP!w{ z1w<3L9}6FWx?DDe-VbIz@jz)2PapF z4h$UM-z%O%pI(PV`zg`7&i%RtCj%2@zrpy}X}FM~pn*<15LfVXF08OffpkEXdpP_X z(jXu-tMPW8Yc5EP1?J7QHF3mdH@-dDUR&#cr9VC2D3LE1fh$4cxX+Y3XcQ+9my*I1 zi38^f1{2g;=nRO#IAfm=O!WoN)Vl_?$OFO3kx9CZ{vbU(933*b#^dT|^){5jh%rDt zEaPZ|Lj}iqKUEz8mKuyb9#^GkL6HV4n-VxBI2)K)SQ{N7P$(DX=1BSPW_pJLcY-VU zJQe;$-Hb4eYHw?0h04$q*cpr|MgJnPKxn@^;9bl3*J~+kWJmAj#Cn)+&U@F6~9-_Z2*e4kvBSF^* zZ9|w239_GsE-x(NjLXIg5G+1tWnhSjS6fb01i(0go|%~hbb|3-d0akvoJ9$uCNeT| ztj?KD&UVbv#zr67nxDs52|}Rcra?;_#zVlebUs*lkQW1e{{7$gNl;`7V14TAUqbnD z17orO>H`R5gcmcA<}IE-I`az9Ng8HxeSZ9?kd-_*?PLsv-2pii(P;?#4^7|!^UBXi zFtUR%px^=!Ss*E?KilArWRK~hpz0NES1A2j{3Z$J6;Lqv`1nAUi%Upgc;Eo%HUUbu zaDINW!<1Q^Nl`#yWnf`JZzx2-`5qy2;SjLmp)D=oWxR)3fu|Q*2*_ONHGoOVrcp^4 z8Dvkq84Y4&7=*&Em%acSvg0BleSh&Fy{ z3Q|+lKXA;ze7FPS2}46e;OSEf2x!}V05bw;h`79b#MfsEV!B&n(6v&~)MUs!$()Tu zVOb|Clum}>DnxrvP|$0)y$pyj*zR_20M{(5Stj(qVMHc68arYD?p)9vFEf%#=EsL6 ze_lmbl%%kUMM4tRX$Za>VvMlibm0+1l(1@Kh3C(oKOd5?2cRPQ?O^S}G^!sgI^62U zRjS>Hs$hQujIkPIIm`rPW@QaRI~g$xvl6O=x3sl|=##5en#YU{XZj%evcS<>p@Rf@ z4Ol{rSwEDR1fCqA1mG!Rpbflu7!tn&YYp}P_|@(EVp!bQi#AbS6Zu_!lA=qbV+5qX;w*=DM!jTzjh(38ON z7O!vP$OT+r5R%YDv<1onZvWlq%JTApVA6sQSM-INqHMIY^8r`}9@l3k5MiZNRk#Xv zP)L`-n!;(Ia;YaS9}o8T_g~BY#_!AuG@5J$Y=|`AS_l^>^rWzmflBJ@uV8$(7R^HG zdoj_`2on!VEtNiBCw(g{2Y6mE)uCQvV?=~0yzgfQQW=U2Vsa2DH58UN8X6Nik;qXG zbHK^-)?{3TJrmMDc20@n{fcPerKFJXCWV4U;tS5{+{Q-GZ(UuGItXK1AB};Y9%~0g z%FB;-c9|dsK-Gbs^a2e7ZwB2UQpYyLYs8F)^7!#nt=7ugu1vN?n3NP5)pT>?2fhP{ zErN|dGV==2l$)vNgZbGsLGOF^$wev!i2FVCWs4t!j)RO)hyzmbSdSqs1CY2&U1nC0l0Ygv=Q*=?C zFtVf%zXV?IBmJ!sJ7nCaPxt=5vqIQC@W}|v3uYMs4NF&mDf2JiC_$RQbqcS_!%Y%v zx}cew$H&#fD=HV+0+ynSrHfR^DvODG^fy zxw$l;F(kp}sb&H;4Ad4}Z!84@m-+@7xl4Y;1X!^aA9Q6uL+1|^#wy?q6T4dU6N{4i zV}ggip8>VL1KH&L&;W)y5Re!cITG(ssz+7$