-
Notifications
You must be signed in to change notification settings - Fork 68
/
setup.py
165 lines (144 loc) · 6.17 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import platform
import socket
import sys
import urllib.error
import urllib.request
from pathlib import Path
from shutil import copy2, rmtree
from zipfile import ZipFile

from setuptools import find_packages, setup
# Refuse to run on interpreters older than Python 3.7 (only major/minor matter here).
if sys.version_info[:2] < (3, 7):
    sys.exit('Sorry, Python < 3.7 is not supported.')

# NB: this URL is also hard-coded in tmrl.tools.init_package.init_tmrl, which duplicates
# the download logic below; whenever RESOURCES_URL changes, update both places.
RESOURCES_URL = "https://github.com/trackmania-rl/tmrl/releases/download/v0.6.0/resources.zip"
def url_retrieve(url: str, outfile: Path, overwrite: bool = False):
    """
    Download ``url`` into the file ``outfile``.

    Adapted from https://www.scivision.dev/python-switch-urlretrieve-requests-timeout/

    The data is first written to a sibling ``<outfile>.part`` file, which is renamed onto
    ``outfile`` only once the transfer has completed. An interrupted or failed download
    therefore never leaves a truncated ``outfile`` behind. Without this, a later call with
    ``overwrite=False`` would mistake a truncated file for a complete one and skip it.

    Args:
        url: URL to download.
        outfile: full destination file path (``~`` is expanded). Missing parent folders
            are created.
        overwrite: when False (default), do nothing if ``outfile`` already exists as a file.

    Raises:
        ValueError: if ``outfile`` is an existing directory.
        ConnectionError: if the download fails with ``socket.gaierror`` or
            ``urllib.error.URLError``. The original error is chained as ``__cause__``.
    """
    outfile = Path(outfile).expanduser().resolve()
    if outfile.is_dir():
        raise ValueError("Please specify full filepath, including filename")
    if not overwrite and outfile.is_file():
        return
    outfile.parent.mkdir(parents=True, exist_ok=True)
    partial = outfile.with_name(outfile.name + ".part")
    try:
        try:
            urllib.request.urlretrieve(url, str(partial))
        except (socket.gaierror, urllib.error.URLError) as err:
            raise ConnectionError(f"could not download {url} due to {err}") from err
        # Path.replace overwrites an existing outfile on every platform (unlike rename on Windows).
        partial.replace(outfile)
    finally:
        # Only still present if something failed before the rename: drop the incomplete data.
        if partial.exists():
            partial.unlink()
# destination folder:
HOME_FOLDER = Path.home()
TMRL_FOLDER = HOME_FOLDER / "TmrlData"


def _install_tmrl_resources(tmrl_folder: Path) -> Path:
    """
    Create the TmrlData folder layout and populate it from the resources archive.

    Creates the ``checkpoints``, ``dataset``, ``reward``, ``weights`` and ``config``
    sub-folders. Then downloads ``RESOURCES_URL`` to ``tmrl_folder / "resources.zip"``,
    extracts it into ``tmrl_folder`` and deletes the archive. Finally copies the default
    ``config.json``, ``reward.pkl`` and the two pretrained ``.tmod`` weight files into place.

    Args:
        tmrl_folder: root of the TmrlData tree.

    Returns:
        The extracted resources folder, ``tmrl_folder / "resources"``.
    """
    checkpoints_folder = tmrl_folder / "checkpoints"
    dataset_folder = tmrl_folder / "dataset"
    reward_folder = tmrl_folder / "reward"
    weights_folder = tmrl_folder / "weights"
    config_folder = tmrl_folder / "config"
    for folder in (checkpoints_folder, dataset_folder, reward_folder, weights_folder, config_folder):
        folder.mkdir(parents=True, exist_ok=True)

    # download resources, unzip them, then delete the archive:
    resources_target = tmrl_folder / "resources.zip"
    url_retrieve(RESOURCES_URL, resources_target)
    with ZipFile(resources_target, 'r') as zip_ref:
        zip_ref.extractall(tmrl_folder)
    resources_target.unlink()

    # copy relevant files:
    resources_folder = tmrl_folder / "resources"
    copy2(resources_folder / "config.json", config_folder)
    copy2(resources_folder / "reward.pkl", reward_folder)
    copy2(resources_folder / "SAC_4_LIDAR_pretrained.tmod", weights_folder)
    copy2(resources_folder / "SAC_4_imgs_pretrained.tmod", weights_folder)
    return resources_folder


def _install_openplanet_plugins(openplanet_folder: Path, resources_folder: Path) -> None:
    """
    Best-effort installation of the TMRL OpenPlanet plugins.

    Removes the legacy ``Plugin_GrabData_0_1`` / ``Plugin_GrabData_0_2`` scripts (and their
    ``.sig`` files) from ``openplanet_folder / "Scripts"``. Then copies ``TMRL_GrabData.op``
    and ``TMRL_SaveGhost.op`` from ``resources_folder / "Plugins"`` into
    ``openplanet_folder / "Plugins"``.

    Problems are reported with ``print`` instead of being raised, so they never abort the
    package installation. The user is asked to copy the plugin manually instead.

    Args:
        openplanet_folder: the OpenPlanet user folder (``OpenplanetNext``).
        resources_folder: the extracted TMRL resources folder.
    """
    if not openplanet_folder.exists():
        # warn the user that OpenPlanet couldn't be found:
        print(f"The OpenPlanet folder was not found at {openplanet_folder}. "
              "Please copy the OpenPlanet script and signature manually for TrackMania 2020 support.")
        return
    try:
        # remove old scripts if found
        op_scripts_folder = openplanet_folder / 'Scripts'
        if op_scripts_folder.exists():
            for old_name in ('Plugin_GrabData_0_1.as', 'Plugin_GrabData_0_1.as.sig',
                             'Plugin_GrabData_0_2.as', 'Plugin_GrabData_0_2.as.sig'):
                old_file = op_scripts_folder / old_name
                if old_file.exists():
                    old_file.unlink()
        # copy new plugins
        op_plugins_folder = openplanet_folder / 'Plugins'
        op_plugins_folder.mkdir(parents=True, exist_ok=True)
        copy2(resources_folder / 'Plugins' / 'TMRL_GrabData.op', op_plugins_folder)
        copy2(resources_folder / 'Plugins' / 'TMRL_SaveGhost.op', op_plugins_folder)
    except Exception as e:
        # Deliberately broad: plugin installation is optional and must not fail the install.
        print("An exception was caught when trying to copy the OpenPlanet plugin automatically. "
              f"Please copy the plugin manually for TrackMania 2020 support. The caught exception was: {str(e)}.")


# download relevant items IF THE tmrl FOLDER DOESN'T EXIST:
if not TMRL_FOLDER.exists():
    try:
        RESOURCES_FOLDER = _install_tmrl_resources(TMRL_FOLDER)
    except BaseException:
        # TMRL_FOLDER did not exist before this run, so everything under it was created here.
        # Remove the half-built tree. Otherwise the exists() check above would skip the
        # download on every later install and leave TmrlData without its config/weights.
        rmtree(TMRL_FOLDER, ignore_errors=True)
        raise

    # on Windows, look for OpenPlanet:
    if platform.system() == "Windows":
        _install_openplanet_plugins(HOME_FOLDER / "OpenplanetNext", RESOURCES_FOLDER)
# Core runtime dependencies, shared by every platform.
install_req = [
    'numpy',
    'torch>=2.0',
    'pandas',
    'gymnasium',
    'rtgym>=0.13',
    'pyyaml',
    'wandb',
    'requests',
    'opencv-python',
    'pyautogui',
    'pyinstrument',
    'tlspyo>=0.2.5',
    'chardet',  # requests dependency
    'packaging',
]

# Dependencies for the TrackMania pipeline, keyed by platform.system().
# Other platforms get no extra requirements.
install_req += {
    "Windows": ['pywin32>=303', 'vgamepad'],
    "Linux": ['mss', 'vgamepad>=0.1.0'],
}.get(platform.system(), [])
# Short readme for PyPI, read from <setup.py folder>/readme/pypi.md.
HERE = os.path.abspath(os.path.dirname(__file__))
README_FOLDER = os.path.join(HERE, "readme")
# Decode explicitly as UTF-8. Without an encoding argument, open() uses the locale's
# preferred encoding (e.g. cp1252 on Windows), which can fail on non-ASCII markdown.
with open(os.path.join(README_FOLDER, "pypi.md"), encoding="utf-8") as fid:
    README = fid.read()
setup(
    name='tmrl',
    version='0.6.6',
    description='Network-based framework for real-time robot learning',
    long_description=README,
    long_description_content_type='text/markdown',
    keywords='reinforcement learning, robot learning, trackmania, self driving, roborace',
    url='https://github.com/trackmania-rl/tmrl',
    # keep the tag in sync with `version` above
    download_url='https://github.com/trackmania-rl/tmrl/archive/refs/tags/v0.6.6.tar.gz',
    author='Yann Bouteiller, Edouard Geze',
    # NOTE(review): "[email protected]" is not an email address; it looks like a
    # web-page email-obfuscation artifact. Restore the real author addresses.
    author_email='[email protected], [email protected]',
    license='MIT',
    # Declarative counterpart of the runtime sys.exit() guard for Python < 3.7 earlier in this
    # file. It lets pip skip this release on unsupported interpreters instead of failing mid-install.
    python_requires='>=3.7',
    install_requires=install_req,
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Information Technology',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python',
        'Topic :: Games/Entertainment',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    include_package_data=True,
    extras_require={},
    scripts=[],
    packages=find_packages(exclude=("tests", )))