-
Notifications
You must be signed in to change notification settings - Fork 34
/
frontend.html
892 lines (709 loc) · 48.5 KB
/
frontend.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>The C++ Frontend — PyTorch main documentation</title>
<link rel="canonical" href="https://pytorch.org/docs/stable/frontend.html"/>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="_static/cpp_theme.css" type="text/css" />
<link rel="stylesheet" href="_static/collapsible-lists/css/tree_view.css" type="text/css" />
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="Library API" href="api/library_root.html" />
<link rel="prev" title="Installing C++ Distributions of PyTorch" href="installing.html" />
<!-- Google Tag Manager -->
<script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],
j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src=
'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f);
})(window,document,'script','dataLayer','');</script>
<!-- End Google Tag Manager -->
<script src="_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
</head>
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li class="main-menu-item">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Learn
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/get-started">
<span class="dropdown-title">Get Started</span>
<p>Run PyTorch locally or get started quickly with one of the supported cloud platforms</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials">
<span class="dropdown-title">Tutorials</span>
<p>What's new in PyTorch tutorials</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/basics/intro.html">
<span class="dropdown-title">Learn the Basics</span>
<p>Familiarize yourself with PyTorch concepts and modules</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/recipes/recipes_index.html">
<span class="dropdown-title">PyTorch Recipes</span>
<p>Bite-size, ready-to-deploy PyTorch code examples</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/tutorials/beginner/introyt.html">
<span class="dropdown-title">Intro to PyTorch - YouTube Series</span>
<p>Master PyTorch basics with our engaging YouTube tutorial series</p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Ecosystem
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/ecosystem">
<span class="dropdown-title">Tools</span>
<p>Learn about the tools and frameworks in the PyTorch Ecosystem</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/#community-module">
<span class="dropdown-title">Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org/" target="_blank" rel="noopener">
<span class="dropdown-title">Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/resources">
<span class="dropdown-title">Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/ecosystem/contributor-awards-2023">
<span class="dropdown-title">Contributor Awards - 2023</span>
<p>Award winners announced at this year's PyTorch Conference</p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Edge
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/edge">
<span class="dropdown-title">About PyTorch Edge</span>
<p>Build innovative and privacy-aware AI experiences for edge devices</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/executorch-overview">
<span class="dropdown-title">ExecuTorch</span>
<p>End-to-end solution for enabling on-device inference capabilities across mobile and edge devices</p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/docs/stable/index.html">
<span class="dropdown-title">PyTorch</span>
<p>Explore the documentation for comprehensive guidance on how to use PyTorch</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/pytorch-domains">
<span class="dropdown-title">PyTorch Domains</span>
<p>Read the PyTorch Domains documentation to learn more about domain-specific libraries</p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
Blogs &amp; News
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/blog/">
<span class="dropdown-title">PyTorch Blog</span>
<p>Catch up on the latest technical news and happenings</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/community-blog">
<span class="dropdown-title">Community Blog</span>
<p>Stories from the PyTorch ecosystem</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/videos">
<span class="dropdown-title">Videos</span>
<p>Learn about the latest PyTorch tutorials, news, and more</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/events">
<span class="dropdown-title">Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="with-down-arrow">
About
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/foundation">
<span class="dropdown-title">PyTorch Foundation</span>
<p>Learn more about the PyTorch Foundation</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/governing-board">
<span class="dropdown-title">Governing Board</span>
<p></p>
</a>
</div>
</div>
</li>
<li class="main-menu-item">
<div class="no-dropdown">
<a href="https://pytorch.org/join" data-cta="join">
Become a Member
</a>
</div>
</li>
<li>
<div class="main-menu-item">
<a href="https://github.com/pytorch/pytorch" class="github-icon" aria-label="PyTorch on GitHub">
</a>
</div>
</li>
<!--- TODO: This block adds the search icon to the nav bar. We will enable it later.
<li>
<div class="main-menu-item">
<a href="https://github.com/pytorch/pytorch" class="search-icon">
</a>
</div>
</li>
--->
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu"></a>
</div>
</div>
</div>
<body class="pytorch-body">
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
main
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
<input type="text" name="q" placeholder="Search Docs" aria-label="Search Docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="installing.html">Installing C++ Distributions of PyTorch</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">The C++ Frontend</a></li>
<li class="toctree-l1"><a class="reference internal" href="api/library_root.html">Library API</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Notes</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="notes/faq.html">FAQ</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/inference_mode.html">Inference Mode</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/maybe_owned.html">MaybeOwned&lt;Tensor&gt;</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/tensor_basics.html">Tensor Basics</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/tensor_creation.html">Tensor Creation API</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/tensor_cuda_stream.html">Tensor CUDA Stream API</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/tensor_indexing.html">Tensor Indexing API</a></li>
<li class="toctree-l1"><a class="reference internal" href="notes/versioning.html">Library Versioning</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="index.html">
Docs
</a> &gt;
</li>
<li>The C++ Frontend</li>
<li class="pytorch-breadcrumbs-aside">
<a href="_sources/frontend.rst.txt" rel="nofollow"><img src="_static/images/view-page-source-icon.svg" alt="View page source"></a>
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<!-- Google Tag Manager (noscript) -->
<noscript><iframe src="https://www.googletagmanager.com/ns.html?id="
height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript>
<!-- End Google Tag Manager (noscript) -->
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<div class="section" id="the-c-frontend">
<h1>The C++ Frontend<a class="headerlink" href="#the-c-frontend" title="Permalink to this heading">¶</a></h1>
<p>The PyTorch C++ frontend is a C++17 library for CPU and GPU
tensor computation, with automatic differentiation and high level building
blocks for state of the art machine learning applications.</p>
<div class="section" id="description">
<h2>Description<a class="headerlink" href="#description" title="Permalink to this heading">¶</a></h2>
<p>The PyTorch C++ frontend can be thought of as a C++ version of the
PyTorch Python frontend, providing automatic differentiation and various higher
level abstractions for machine learning and neural networks. Specifically,
it consists of the following components:</p>
<table class="docutils align-default">
<colgroup>
<col style="width: 23%" />
<col style="width: 77%" />
</colgroup>
<thead>
<tr class="row-odd"><th class="head"><p>Component</p></th>
<th class="head"><p>Description</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p><code class="docutils literal notranslate"><span class="pre">torch::Tensor</span></code></p></td>
<td><p>Automatically differentiable, efficient CPU and GPU enabled tensors</p></td>
</tr>
<tr class="row-odd"><td><p><code class="docutils literal notranslate"><span class="pre">torch::nn</span></code></p></td>
<td><p>A collection of composable modules for neural network modeling</p></td>
</tr>
<tr class="row-even"><td><p><code class="docutils literal notranslate"><span class="pre">torch::optim</span></code></p></td>
<td><p>Optimization algorithms like SGD, Adam or RMSprop to train your models</p></td>
</tr>
<tr class="row-odd"><td><p><code class="docutils literal notranslate"><span class="pre">torch::data</span></code></p></td>
<td><p>Datasets, data pipelines and multi-threaded, asynchronous data loader</p></td>
</tr>
<tr class="row-even"><td><p><code class="docutils literal notranslate"><span class="pre">torch::serialize</span></code></p></td>
<td><p>A serialization API for storing and loading model checkpoints</p></td>
</tr>
<tr class="row-odd"><td><p><code class="docutils literal notranslate"><span class="pre">torch::python</span></code></p></td>
<td><p>Glue to bind your C++ models into Python</p></td>
</tr>
<tr class="row-even"><td><p><code class="docutils literal notranslate"><span class="pre">torch::jit</span></code></p></td>
<td><p>Pure C++ access to the TorchScript JIT compiler</p></td>
</tr>
</tbody>
</table>
</div>
<div class="section" id="end-to-end-example">
<h2>End-to-end example<a class="headerlink" href="#end-to-end-example" title="Permalink to this heading">¶</a></h2>
<p>Here is a simple, end-to-end example of defining and training a simple
neural network on the MNIST dataset:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="cp">#include</span><span class="w"> </span><span class="cpf">&lt;torch/torch.h&gt;</span>
<span class="c1">// Define a new Module.</span>
<span class="k">struct</span><span class="w"> </span><span class="nc">Net</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nn</span><span class="o">::</span><span class="n">Module</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">Net</span><span class="p">()</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Construct and register three Linear submodules.</span>
<span class="w"> </span><span class="n">fc1</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">register_module</span><span class="p">(</span><span class="s">"fc1"</span><span class="p">,</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nn</span><span class="o">::</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span><span class="w"> </span><span class="mi">64</span><span class="p">));</span>
<span class="w"> </span><span class="n">fc2</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">register_module</span><span class="p">(</span><span class="s">"fc2"</span><span class="p">,</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nn</span><span class="o">::</span><span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span><span class="w"> </span><span class="mi">32</span><span class="p">));</span>
<span class="w"> </span><span class="n">fc3</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">register_module</span><span class="p">(</span><span class="s">"fc3"</span><span class="p">,</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nn</span><span class="o">::</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span><span class="w"> </span><span class="mi">10</span><span class="p">));</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// Implement the Net's algorithm.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="w"> </span><span class="n">forward</span><span class="p">(</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="w"> </span><span class="n">x</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Use one of many tensor manipulation functions.</span>
<span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">relu</span><span class="p">(</span><span class="n">fc1</span><span class="o">-></span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">({</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span><span class="w"> </span><span class="mi">784</span><span class="p">})));</span>
<span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="cm">/*p=*/</span><span class="mf">0.5</span><span class="p">,</span><span class="w"> </span><span class="cm">/*train=*/</span><span class="n">is_training</span><span class="p">());</span>
<span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">relu</span><span class="p">(</span><span class="n">fc2</span><span class="o">-></span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">));</span>
<span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">fc3</span><span class="o">-></span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="w"> </span><span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">);</span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">x</span><span class="p">;</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="c1">// Use one of many "standard library" modules.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nn</span><span class="o">::</span><span class="n">Linear</span><span class="w"> </span><span class="n">fc1</span><span class="p">{</span><span class="k">nullptr</span><span class="p">},</span><span class="w"> </span><span class="n">fc2</span><span class="p">{</span><span class="k">nullptr</span><span class="p">},</span><span class="w"> </span><span class="n">fc3</span><span class="p">{</span><span class="k">nullptr</span><span class="p">};</span>
<span class="p">};</span>
<span class="kt">int</span><span class="w"> </span><span class="nf">main</span><span class="p">()</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Create a new Net.</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">net</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">make_shared</span><span class="o"><</span><span class="n">Net</span><span class="o">></span><span class="p">();</span>
<span class="w"> </span><span class="c1">// Create a multi-threaded data loader for the MNIST dataset.</span>
<span class="w"> </span><span class="k">auto</span><span class="w"> </span><span class="n">data_loader</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">data</span><span class="o">::</span><span class="n">make_data_loader</span><span class="p">(</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">data</span><span class="o">::</span><span class="n">datasets</span><span class="o">::</span><span class="n">MNIST</span><span class="p">(</span><span class="s">"./data"</span><span class="p">).</span><span class="n">map</span><span class="p">(</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">data</span><span class="o">::</span><span class="n">transforms</span><span class="o">::</span><span class="n">Stack</span><span class="o"><></span><span class="p">()),</span>
<span class="w"> </span><span class="cm">/*batch_size=*/</span><span class="mi">64</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Instantiate an SGD optimization algorithm to update our Net's parameters.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">optim</span><span class="o">::</span><span class="n">SGD</span><span class="w"> </span><span class="n">optimizer</span><span class="p">(</span><span class="n">net</span><span class="o">-></span><span class="n">parameters</span><span class="p">(),</span><span class="w"> </span><span class="cm">/*lr=*/</span><span class="mf">0.01</span><span class="p">);</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="kt">size_t</span><span class="w"> </span><span class="n">epoch</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">1</span><span class="p">;</span><span class="w"> </span><span class="n">epoch</span><span class="w"> </span><span class="o"><=</span><span class="w"> </span><span class="mi">10</span><span class="p">;</span><span class="w"> </span><span class="o">++</span><span class="n">epoch</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="kt">size_t</span><span class="w"> </span><span class="n">batch_index</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="mi">0</span><span class="p">;</span>
<span class="w"> </span><span class="c1">// Iterate the data loader to yield batches from the dataset.</span>
<span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="p">(</span><span class="k">auto</span><span class="o">&</span><span class="w"> </span><span class="n">batch</span><span class="w"> </span><span class="o">:</span><span class="w"> </span><span class="o">*</span><span class="n">data_loader</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="c1">// Reset gradients.</span>
<span class="w"> </span><span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">();</span>
<span class="w"> </span><span class="c1">// Execute the model on the input data.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="w"> </span><span class="n">prediction</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">net</span><span class="o">-></span><span class="n">forward</span><span class="p">(</span><span class="n">batch</span><span class="p">.</span><span class="n">data</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Compute a loss value to judge the prediction of our model.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="w"> </span><span class="n">loss</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">nll_loss</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span><span class="w"> </span><span class="n">batch</span><span class="p">.</span><span class="n">target</span><span class="p">);</span>
<span class="w"> </span><span class="c1">// Compute gradients of the loss w.r.t. the parameters of our model.</span>
<span class="w"> </span><span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">();</span>
<span class="w"> </span><span class="c1">// Update the parameters based on the calculated gradients.</span>
<span class="w"> </span><span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">();</span>
<span class="w"> </span><span class="c1">// Output the loss and checkpoint every 100 batches.</span>
<span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="p">(</span><span class="o">++</span><span class="n">batch_index</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="mi">100</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="mi">0</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">cout</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="s">"Epoch: "</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="n">epoch</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="s">" | Batch: "</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="n">batch_index</span>
<span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="s">" | Loss: "</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">()</span><span class="w"> </span><span class="o">&lt;&lt;</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">endl</span><span class="p">;</span>
<span class="w"> </span><span class="c1">// Serialize your model periodically as a checkpoint.</span>
<span class="w"> </span><span class="n">torch</span><span class="o">::</span><span class="n">save</span><span class="p">(</span><span class="n">net</span><span class="p">,</span><span class="w"> </span><span class="s">"net.pt"</span><span class="p">);</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="p">}</span>
<span class="w"> </span><span class="p">}</span>
<span class="p">}</span>
</pre></div>
</div>
<p>To see more complete examples of using the PyTorch C++ frontend, see <a class="reference external" href="https://github.com/pytorch/examples/tree/master/cpp">the example repository</a>.</p>
</div>
<div class="section" id="philosophy">
<h2>Philosophy<a class="headerlink" href="#philosophy" title="Permalink to this heading">¶</a></h2>
<p>PyTorch’s C++ frontend was designed with the idea that the Python frontend is
great, and should be used when possible; but in some settings, performance and
portability requirements make the use of the Python interpreter infeasible. For
example, Python is a poor choice for low-latency, high-performance or
multithreaded environments, such as video games or production servers. The
goal of the C++ frontend is to address these use cases, while not sacrificing
the user experience of the Python frontend.</p>
<p>As such, the C++ frontend has been written with a few philosophical goals in mind:</p>
<ul class="simple">
<li><p><strong>Closely model the Python frontend in its design</strong>, naming, conventions and
functionality. While there may be occasional differences between the two
frontends (e.g., where we have dropped deprecated features or fixed “warts”
in the Python frontend), we guarantee that the effort in porting a Python model
to C++ should lie exclusively in <strong>translating language features</strong>,
not modifying functionality or behavior.</p></li>
<li><p><strong>Prioritize flexibility and user-friendliness over micro-optimization.</strong>
In C++, you can often get optimal code, but at the cost of an extremely
unfriendly user experience. Flexibility and dynamism are at the heart of
PyTorch, and the C++ frontend seeks to preserve this experience, in some
cases sacrificing performance (or “hiding” performance knobs) to keep APIs
simple and explicable. We want researchers who don’t write C++ for a living
to be able to use our APIs.</p></li>
</ul>
<p>A word of warning: Python is not necessarily slower than
C++! The Python frontend calls into C++ for almost anything computationally expensive
(especially any kind of numeric operation), and these operations will take up
the bulk of time spent in a program. If you would prefer to write Python,
and can afford to write Python, we recommend using the Python interface to
PyTorch. However, if you would prefer to write C++, or need to write C++
(because of multithreading, latency or deployment requirements), the
C++ frontend to PyTorch provides an API that is approximately as convenient,
flexible, friendly and intuitive as its Python counterpart. The two frontends
serve different use cases, work hand in hand, and neither is meant to
unconditionally replace the other.</p>
</div>
<div class="section" id="installation">
<h2>Installation<a class="headerlink" href="#installation" title="Permalink to this heading">¶</a></h2>
<p>Instructions on how to install the C++ frontend library distribution, including
an example for how to build a minimal application depending on LibTorch, may be
found in the <a class="reference external" href="https://pytorch.org/cppdocs/installing.html">installation guide</a>.</p>
</div>
</div>
</article>
</div>
<footer>
<!-- Previous/next page navigation. The chevron images are decorative: the
     visible "Next" / "Previous" text already names each link, so they carry
     an empty alt to keep screen readers from announcing the file name. -->
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
  <a href="api/library_root.html" class="btn btn-neutral float-right" title="Library API" accesskey="n" rel="next">Next <img src="_static/images/chevron-right-orange.svg" class="next-page" alt=""></a>
  <a href="installing.html" class="btn btn-neutral" title="Installing C++ Distributions of PyTorch" accesskey="p" rel="prev"><img src="_static/images/chevron-right-orange.svg" class="previous-page" alt=""> Previous</a>
</div>
<hr>
<div role="contentinfo">
<p>
© Copyright 2024, PyTorch Contributors.
</p>
</div>
<div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</div>
</footer>
</div>
</div>
<div class="pytorch-content-right" id="pytorch-content-right">
<div class="pytorch-right-menu" id="pytorch-right-menu">
<div class="pytorch-side-scroll" id="pytorch-side-scroll-right">
<ul>
<li><a class="reference internal" href="#">The C++ Frontend</a><ul>
<li><a class="reference internal" href="#description">Description</a></li>
<li><a class="reference internal" href="#end-to-end-example">End-to-end example</a></li>
<li><a class="reference internal" href="#philosophy">Philosophy</a></li>
<li><a class="reference internal" href="#installation">Installation</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</section>
</div>
<!-- Include exactly once: an id must be unique within the document.
     NOTE(review): data-url_root is presumably read by the Sphinx scripts
     through this id - confirm before renaming or removing it. -->
<script id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
<script src="_static/jquery.js"></script>
<script src="_static/underscore.js"></script>
<script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="_static/doctools.js"></script>
<script src="_static/sphinx_highlight.js"></script>
<script src="_static/collapsible-lists/js/CollapsibleLists.compressed.js"></script>
<script src="_static/collapsible-lists/js/apply-collapsible-lists.js"></script>
<script type="text/javascript" src="_static/js/vendor/popper.min.js"></script>
<script type="text/javascript" src="_static/js/vendor/bootstrap.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/list.js/1.5.0/list.min.js"></script>
<script type="text/javascript" src="_static/js/theme.js"></script>
<script type="text/javascript">
// Turn on the Sphinx RTD theme navigation once the DOM is ready.
// SphinxRtdTheme is not defined in this page; presumably it comes from
// _static/js/theme.js, which also defines what the `true` argument means.
jQuery(function enableThemeNavigation() {
  SphinxRtdTheme.Navigation.enable(true);
});
</script>
<!-- Begin Footer -->
<div class="container-fluid docs-tutorials-resources" id="docs-tutorials-resources">
<div class="container">
<div class="row">
<div class="col-md-4 text-center">
<h2>Docs</h2>
<p>Access comprehensive developer documentation for PyTorch</p>
<a class="with-right-arrow" href="https://pytorch.org/docs/stable/index.html">View Docs</a>
</div>
<div class="col-md-4 text-center">
<h2>Tutorials</h2>
<p>Get in-depth tutorials for beginners and advanced developers</p>
<a class="with-right-arrow" href="https://pytorch.org/tutorials">View Tutorials</a>
</div>
<div class="col-md-4 text-center">
<h2>Resources</h2>
<p>Find development resources and get your questions answered</p>
<a class="with-right-arrow" href="https://pytorch.org/resources">View Resources</a>
</div>
</div>
</div>
</div>
<footer class="site-footer">
<div class="container footer-container">
<!-- The logo link has no text content, so it needs an explicit accessible
     name; this matches the mobile menu's header logo link. -->
<div class="footer-logo-wrapper">
  <a href="https://pytorch.org/" class="footer-logo" aria-label="PyTorch"></a>
</div>
<!-- Site footer link columns. Every link that opens a new browsing context
     (target="_blank") carries rel="noopener" so the opened page cannot reach
     back into this one through window.opener in browsers that do not imply it. -->
<div class="footer-links-wrapper">
  <div class="footer-links-col">
    <ul>
      <li class="list-title"><a href="https://pytorch.org/">PyTorch</a></li>
      <li><a href="https://pytorch.org/get-started">Get Started</a></li>
      <li><a href="https://pytorch.org/features">Features</a></li>
      <li><a href="https://pytorch.org/ecosystem">Ecosystem</a></li>
      <li><a href="https://pytorch.org/blog/">Blog</a></li>
      <li><a href="https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md">Contributing</a></li>
    </ul>
  </div>
  <div class="footer-links-col">
    <ul>
      <li class="list-title"><a href="https://pytorch.org/resources">Resources</a></li>
      <li><a href="https://pytorch.org/tutorials">Tutorials</a></li>
      <li><a href="https://pytorch.org/docs/stable/index.html">Docs</a></li>
      <li><a href="https://discuss.pytorch.org" target="_blank" rel="noopener">Discuss</a></li>
      <li><a href="https://github.com/pytorch/pytorch/issues" target="_blank" rel="noopener">Github Issues</a></li>
      <li><a href="https://pytorch.org/assets/brand-guidelines/PyTorch-Brand-Guidelines.pdf" target="_blank" rel="noopener">Brand Guidelines</a></li>
    </ul>
  </div>
  <div class="footer-links-col">
    <ul>
      <li class="list-title">Stay up to date</li>
      <li><a href="https://www.facebook.com/pytorch" target="_blank" rel="noopener">Facebook</a></li>
      <li><a href="https://twitter.com/pytorch" target="_blank" rel="noopener">Twitter</a></li>
      <li><a href="https://www.youtube.com/pytorch" target="_blank" rel="noopener">YouTube</a></li>
      <li><a href="https://www.linkedin.com/company/pytorch" target="_blank" rel="noopener">LinkedIn</a></li>
    </ul>
  </div>
  <div class="footer-links-col">
    <ul>
      <li class="list-title">PyTorch Podcasts</li>
      <li><a href="https://open.spotify.com/show/6UzHKeiy368jKfQMKKvJY5" target="_blank" rel="noopener">Spotify</a></li>
      <li><a href="https://podcasts.apple.com/us/podcast/pytorch-developer-podcast/id1566080008" target="_blank" rel="noopener">Apple</a></li>
      <li><a href="https://www.google.com/podcasts?feed=aHR0cHM6Ly9mZWVkcy5zaW1wbGVjYXN0LmNvbS9PQjVGa0lsOA%3D%3D" target="_blank" rel="noopener">Google</a></li>
      <li><a href="https://music.amazon.com/podcasts/7a4e6f0e-26c2-49e9-a478-41bd244197d0/PyTorch-Developer-Podcast?" target="_blank" rel="noopener">Amazon</a></li>
    </ul>
  </div>
</div>
<!-- Legal links. The "|" item is a purely visual separator, so it is hidden
     from assistive technology; the new-tab links carry rel="noopener". -->
<div class="privacy-policy">
  <ul>
    <li class="privacy-policy-links"><a href="https://www.linuxfoundation.org/terms/" target="_blank" rel="noopener">Terms</a></li>
    <li class="privacy-policy-links" aria-hidden="true">|</li>
    <li class="privacy-policy-links"><a href="https://www.linuxfoundation.org/privacy-policy/" target="_blank" rel="noopener">Privacy</a></li>
  </ul>
</div>
<div class="copyright">
<p>© Copyright The Linux Foundation. The PyTorch Foundation is a project of The Linux Foundation.
For web site terms of use, trademark policy and other policies applicable to The PyTorch Foundation please see
<a href="https://www.linuxfoundation.org/policies/">www.linuxfoundation.org/policies/</a>. The PyTorch Foundation supports the PyTorch open source
project, which has been established as PyTorch Project a Series of LF Projects, LLC. For policies applicable to the PyTorch Project a Series of LF Projects, LLC,
please see <a href="https://www.lfprojects.org/policies/">www.lfprojects.org/policies/</a>.</p>
</div>
</div>
</footer>
<!-- Cookie consent banner. The close image is the banner's only dismiss
     control, so its alt text describes the action rather than the picture.
     NOTE(review): the click handler is presumably bound to .close-button by
     the theme script - the class is kept unchanged. -->
<div class="cookie-banner-wrapper">
  <div class="container">
    <p class="gdpr-notice">To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: <a href="https://www.facebook.com/policies/cookies/">Cookies Policy</a>.</p>
    <img class="close-button" src="_static/images/pytorch-x.svg" alt="Close cookie notice">
  </div>
</div>
<!-- End Footer -->
<!-- Begin Mobile Menu -->
<div class="mobile-main-menu">
<div class="container-fluid">
<div class="container">
<!-- Mobile menu header. Both links are rendered without text content, so each
     needs an aria-label to have an accessible name. -->
<div class="mobile-main-menu-header-container">
  <a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
  <a class="main-menu-close-button" href="#" data-behavior="close-mobile-menu" aria-label="Close menu"></a>
</div>
</div>
</div>
<div class="mobile-main-menu-links-container">
<div class="main-menu">
<ul>
<li class="resources-mobile-menu-title">
<a>Learn</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/beginner/basics/intro.html">Learn the Basics</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/recipes/recipes_index.html">PyTorch Recipes</a>
</li>
<li>
<a href="https://pytorch.org/tutorials/beginner/introyt.html">Introduction to PyTorch - YouTube Series</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Ecosystem</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/ecosystem">Tools</a>
</li>
<li>
<a href="https://pytorch.org/#community-module">Community</a>
</li>
<li>
<a href="https://discuss.pytorch.org/">Forums</a>
</li>
<li>
<a href="https://pytorch.org/resources">Developer Resources</a>
</li>
<li>
<a href="https://pytorch.org/ecosystem/contributor-awards-2023">Contributor Awards - 2023</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Edge</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/edge">About PyTorch Edge</a>
</li>
<li>
<a href="https://pytorch.org/executorch-overview">ExecuTorch</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Docs</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/docs/stable/index.html">PyTorch</a>
</li>
<li>
<a href="https://pytorch.org/pytorch-domains">PyTorch Domains</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>Blog &amp; News</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/blog/">PyTorch Blog</a>
</li>
<li>
<a href="https://pytorch.org/community-blog">Community Blog</a>
</li>
<li>
<a href="https://pytorch.org/videos">Videos</a>
</li>
<li>
<a href="https://pytorch.org/community-stories">Community Stories</a>
</li>
<li>
<a href="https://pytorch.org/events">Events</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
<a>About</a>
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/foundation">PyTorch Foundation</a>
</li>
<li>
<a href="https://pytorch.org/governing-board">Governing Board</a>
</li>
</ul>
</ul>
</div>
</div>
</div>
<!-- End Mobile Menu -->
<script type="text/javascript" src="_static/js/vendor/anchor.min.js"></script>
<script type="text/javascript">
// Wire up the page's interactive components once the DOM is ready.
// None of these objects are defined in this page; presumably they are
// provided by _static/js/theme.js, which is loaded earlier.
$(function () {
  // Call order is kept as-is in case the components depend on one another.
  mobileMenu.bind();
  mobileTOC.bind();
  pytorchAnchors.bind();
  sideMenus.bind();
  scrollToAnchor.bind();
  highlightNavigation.bind();
  mainMenuDropdown.bind();
  filterTags.bind();

  // Add class to links that have code blocks, since we cannot create links in code blocks
  $("article.pytorch-article a span.pre").closest("a").addClass("has-code");
});
</script>
</body>
</html>