From ac9b5ac8f3de093cdaa95076a927c4830986df46 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 29 Aug 2023 14:02:19 -0700 Subject: [PATCH] first addition of files --- .gitignore | 72 +++++ Cargo.lock | 296 +++++++++++++++++++++ Cargo.toml | 22 ++ pyproject.toml | 16 ++ src/lib.rs | 706 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1112 insertions(+) create mode 100755 .gitignore create mode 100755 Cargo.lock create mode 100755 Cargo.toml create mode 100755 pyproject.toml create mode 100755 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..af3ca5e --- /dev/null +++ b/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100755 index 0000000..9074f6b --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,296 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cotengrust" +version = "0.1.0" +dependencies = [ + "bit-set", + "pyo3", + "rustc-hash", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "libc" +version = "0.2.147" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" + +[[package]] +name = "lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" + +[[package]] +name = "unicode-ident" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml new file mode 100755 index 0000000..996bb7a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "cotengrust" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "cotengrust" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.19" + +[dependencies.bit-set] +version = "0.5" + +[dependencies.rustc-hash] +version = "1.1" + +[profile.release] +overflow-checks = true +lto = "fat" diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000..a5ae6bf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=0.15,<0.16"] +build-backend = "maturin" + +[project] +name = "cotengrust" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] + + +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100755 index 0000000..4926267 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,706 @@ +// use bit_set::BitSet; +use pyo3::prelude::*; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::collections::{BTreeSet, BinaryHeap}; + +use FxHashMap as Dict; + +type Ix = u16; +type Count = u8; +type Legs = Vec<(Ix, Count)>; +type Node = u32; +// type Subgraph = BitSet; +type Score = u128; +type GreedyScore = i128; +// type BitPath = Vec<(Subgraph, Subgraph)>; +type SSAPath = Vec>; +// type SubContraction = (Legs, Score, BitPath); + +struct HypergraphProcessor { + nodes: Dict, + edges: Dict>, + appearances: Vec, + sizes: Vec, + ssa: Node, + ssa_path: SSAPath, +} + +fn contract_legs_(ilegs: Legs, jlegs: Legs, appearances: &Vec) -> Legs { + let mut ip = 0; + let mut jp = 0; + let ni = ilegs.len(); + let nj = jlegs.len(); + let mut new_legs: Legs = Vec::new(); + + loop { + if ip == ni { + new_legs.extend(jlegs[jp..].iter()); + break; + } + if jp == nj { + new_legs.extend(ilegs[ip..].iter()); + break; + } + + let (ix, ic) = ilegs[ip]; + let (jx, jc) = jlegs[jp]; + + if ix < jx { + // index only appears in ilegs + new_legs.push((ix, ic)); + ip += 1; + } else if ix > jx { + // index only appears in jlegs + new_legs.push((jx, jc)); + jp += 1; + } else { + // index appears in both + let new_count = ic + jc; + if new_count != appearances[ix as usize] { + // not last appearance -> kept index contributes to new size + new_legs.push((ix, new_count)); + } + ip += 1; + jp += 1; + } + } + new_legs +} + +fn compute_size(legs: &Legs, sizes: &Vec) -> Score { + let mut d = 1; + for (ix, _) in legs { + d *= sizes[*ix as usize]; + } + d +} + +impl HypergraphProcessor { + fn new( + inputs: Vec>, + output: Vec, + size_dict: Dict, + ) -> HypergraphProcessor { + let mut nodes: Dict = Dict::default(); + let mut edges: Dict> = Dict::default(); + let mut indmap: Dict = Dict::default(); + let mut sizes: Vec = Vec::with_capacity(size_dict.len()); + let mut appearances: Vec = Vec::with_capacity(size_dict.len()); + let mut c: Ix = 0; + + for (i, term) in inputs.into_iter().enumerate() { + let mut legs = Vec::new(); + for ind in term { + match indmap.get(&ind) { + None => { + // index not parsed yet + indmap.insert(ind, c); + edges.insert(c, vec![i as Node]); + appearances.push(1); + sizes.push(size_dict[&ind]); + legs.push((c, 1)); + c += 1; + } + Some(&ix) => { + // index already present + appearances[ix as usize] += 1; + edges.get_mut(&ix).unwrap().push(i as Node); + legs.push((ix, 1)); + } + }; + } + legs.sort(); + nodes.insert(i as Node, legs); + } + output.into_iter().for_each(|ind| { + appearances[indmap[&ind] as usize] += 1; + }); + + let ssa = nodes.len() as Node; + let ssa_path: SSAPath = Vec::new(); + + HypergraphProcessor { + nodes, + edges, + appearances, + sizes, + ssa, + ssa_path, + } + } + + fn neighbors(&self, i: Node) -> Vec { + let mut js = Vec::new(); + for (ix, _) in self.nodes[&i].iter() { + for j in self.edges[&ix].iter() { + if *j != i { + js.push(*j); + } + } + } + js + } + + fn remove_ix(&mut self, ix: Ix) { + for j in self.edges.remove(&ix).unwrap() { + self.nodes.get_mut(&j).unwrap().retain(|(k, _)| *k != ix); + } + } + + fn pop_node(&mut self, i: Node) -> Legs { + let legs = self.nodes.remove(&i).unwrap(); + for (ix, _) in legs.iter() { + let nodes = self.edges.get_mut(&ix).unwrap(); + if nodes.len() == 1 { + self.edges.remove(&ix); + } else { + nodes.retain(|&j| j != i); + } + } + legs + } + + fn add_node(&mut self, legs: Legs) -> Node { + let i = self.ssa; + self.ssa += 1; + for (ix, _) in &legs { + self.edges + .entry(*ix) + .and_modify(|nodes| nodes.push(i)) + .or_insert(vec![i]); + } + self.nodes.insert(i, legs); + i + } + + fn contract(&mut self, i: Node, j: Node) -> Node { + let ilegs = self.pop_node(i); + let jlegs = self.pop_node(j); + let new_legs = contract_legs_(ilegs, jlegs, &self.appearances); + let k = self.add_node(new_legs); + self.ssa_path.push(vec![i, j]); + k + } + + fn simplify_batch(&mut self) { + let mut ix_to_remove = Vec::new(); + let nterms = self.nodes.len(); + for (ix, ix_nodes) in self.edges.iter() { + if ix_nodes.len() >= nterms { + ix_to_remove.push(*ix); + } + } + for ix in ix_to_remove { + self.remove_ix(ix); + } + } + + fn simplify_single_terms(&mut self) { + for (i, legs) in self.nodes.clone().into_iter() { + if legs + .iter() + .any(|&(ix, c)| c == self.appearances[ix as usize]) + { + let mut legs_reduced = self.pop_node(i); + legs_reduced.retain(|&(ix, c)| c != self.appearances[ix as usize]); + self.add_node(legs_reduced); + self.ssa_path.push(vec![i]); + } + } + } + + fn simplify_scalars(&mut self) { + let mut scalars = Vec::new(); + for (i, legs) in self.nodes.iter() { + if legs.len() == 0 { + scalars.push(*i); + } + } + if scalars.len() > 0 { + for &i in &scalars { + self.pop_node(i); + } + let (res, con) = match self.nodes.iter().min_by_key(|&(_, legs)| legs.len()) { + Some((&j, _)) => { + let res = self.pop_node(j); + let con: Vec = scalars.into_iter().chain(vec![j].into_iter()).collect(); + (res, con) + } + None => { + let res = Vec::new(); + (res, scalars) + } + }; + self.add_node(res); + self.ssa_path.push(con); + } + } + + fn simplify_hadamard(&mut self) { + let mut groups: Dict, Vec> = Dict::default(); + let mut hadamards: BTreeSet> = BTreeSet::default(); + for (i, legs) in self.nodes.iter() { + let key: BTreeSet = legs.iter().map(|&(ix, _)| ix).collect(); + match groups.get_mut(&key) { + Some(group) => { + hadamards.insert(key); + group.push(*i); + } + None => { + groups.insert(key, vec![*i]); + } + } + } + for key in hadamards.into_iter() { + let mut group = groups.remove(&key).unwrap(); + while group.len() > 1 { + let i = group.pop().unwrap(); + let j = group.pop().unwrap(); + let k = self.contract(i, j); + group.push(k); + } + } + } + + fn simplify(&mut self) { + self.simplify_batch(); + let mut should_run = true; + while should_run { + self.simplify_single_terms(); + self.simplify_scalars(); + let ssa_before = self.ssa; + self.simplify_hadamard(); + should_run = ssa_before != self.ssa; + } + } + + fn subgraphs(&self) -> Vec> { + let mut remaining: BTreeSet = BTreeSet::default(); + self.nodes.keys().for_each(|i| { + remaining.insert(*i); + }); + let mut groups: Vec> = Vec::new(); + while remaining.len() > 0 { + let i = remaining.pop_first().unwrap(); + let mut queue: Vec = vec![i]; + let mut group: BTreeSet = vec![i].into_iter().collect(); + while queue.len() > 0 { + let i = queue.pop().unwrap(); + for j in self.neighbors(i) { + if !group.contains(&j) { + group.insert(j); + queue.push(j); + } + } + } + group.iter().for_each(|i| { + remaining.remove(i); + }); + groups.push(group.into_iter().collect()); + } + groups + } + + fn optimize_greedy(&mut self) { + // cache all nodes sizes as we go + let mut node_sizes: Dict = Dict::default(); + self.nodes.iter().for_each(|(&i, legs)| { + node_sizes.insert(i, compute_size(&legs, &self.sizes)); + }); + + let mut queue: BinaryHeap<(GreedyScore, usize)> = BinaryHeap::default(); + let mut contractions: Dict = Dict::default(); + let mut checked: FxHashSet<(Node, Node)> = FxHashSet::default(); + let mut c: usize = 0; + + // get the initial candidate contractions + for ix_nodes in self.edges.values() { + for ip in 0..ix_nodes.len() { + let i = ix_nodes[ip]; + let isize = node_sizes[&i]; + for jp in (ip + 1)..ix_nodes.len() { + let j = ix_nodes[jp]; + let jsize = node_sizes[&j]; + let klegs = contract_legs_( + self.nodes[&i].clone(), + self.nodes[&j].clone(), + &self.appearances, + ); + let ksize = compute_size(&klegs, &self.sizes); + let score = (isize as GreedyScore) - ((jsize + ksize) as GreedyScore); + queue.push((-score, c)); + contractions.insert(c, (i, j, ksize, klegs)); + c += 1; + } + } + } + + // greedily contract remaining + while let Some((_, c0)) = queue.pop() { + let (i, j, ksize, klegs) = contractions.remove(&c0).unwrap(); + if !self.nodes.contains_key(&i) || !self.nodes.contains_key(&j) { + // one of the nodes has been removed -> skip + continue; + } + + self.pop_node(i); + self.pop_node(j); + let k = self.add_node(klegs.clone()); + self.ssa_path.push(vec![i, j]); + node_sizes.insert(k, ksize); + + for l in self.neighbors(k) { + let key = if k < l { (k, l) } else { (l, k) }; + if checked.contains(&key) { + continue; + } else { + checked.insert(key); + } + let llegs = self.nodes[&l].clone(); + let lsize = node_sizes[&l]; + let mlegs = + contract_legs_(klegs.clone(), llegs, &self.appearances); + let msize = compute_size(&mlegs, &self.sizes); + let score = (msize as GreedyScore) - ((ksize + lsize) as GreedyScore); + queue.push((-score, c)); + contractions.insert(c, (k, l, msize, mlegs)); + c += 1; + } + } + } +} + +#[pyfunction] +#[pyo3()] +fn test_simplify( + inputs: Vec>, + output: Vec, + size_dict: Dict, +) -> SSAPath { + let mut hgp = HypergraphProcessor::new(inputs, output, size_dict); + hgp.simplify(); + hgp.ssa_path +} + +#[pyfunction] +#[pyo3()] +fn test_subgraphs( + inputs: Vec>, + output: Vec, + size_dict: Dict, +) -> Vec> { + let hgp = HypergraphProcessor::new(inputs, output, size_dict); + hgp.subgraphs() +} + + +#[pyfunction] +#[pyo3()] +fn test_greedy( + inputs: Vec>, + output: Vec, + size_dict: Dict, +) -> Vec> { + let mut hgp = HypergraphProcessor::new(inputs, output, size_dict); + hgp.simplify(); + hgp.optimize_greedy(); + hgp.ssa_path +} + +// fn single_el_bitset(x: usize, n: usize) -> BitSet { +// let mut a: BitSet = BitSet::with_capacity(n); +// a.insert(x); +// a +// } + +// fn compute_con_cost_flops( +// temp_legs: Legs, +// appearances: &Vec, +// sizes: &Vec, +// iscore: &Score, +// jscore: &Score, +// _factor: Score, +// ) -> (Legs, Score) { +// // remove indices that have reached final appearance +// // and compute cost and size of local contraction +// let mut new_legs: Legs = Legs::new(); +// let mut cost: Score = 1; +// for (ix, ix_count) in temp_legs.into_iter() { +// // all involved indices contribute to the cost +// let d = sizes[ix as usize]; +// cost *= d; +// if ix_count != appearances[ix as usize] { +// // not last appearance -> kept index contributes to new size +// new_legs.push((ix, ix_count)); +// } +// } +// let new_score = iscore + jscore + cost; +// (new_legs, new_score) +// } + +// fn compute_con_cost_size( +// temp_legs: Legs, +// appearances: &Vec, +// sizes: &Vec, +// iscore: &Score, +// jscore: &Score, +// _factor: Score, +// ) -> (Legs, Score) { +// // remove indices that have reached final appearance +// // and compute cost and size of local contraction +// let mut new_legs: Legs = Legs::new(); +// let mut size: Score = 1; +// for (ix, ix_count) in temp_legs.into_iter() { +// if ix_count != appearances[ix as usize] { +// // not last appearance -> kept index contributes to new size +// new_legs.push((ix, ix_count)); +// size *= sizes[ix as usize]; +// } +// } +// let new_score = *iscore.max(jscore).max(&size); +// (new_legs, new_score) +// } + +// fn compute_con_cost_write( +// temp_legs: Legs, +// appearances: &Vec, +// sizes: &Vec, +// iscore: &Score, +// jscore: &Score, +// _factor: Score, +// ) -> (Legs, Score) { +// // remove indices that have reached final appearance +// // and compute cost and size of local contraction +// let mut new_legs: Legs = Legs::new(); +// let mut size: Score = 1; +// for (ix, ix_count) in temp_legs.into_iter() { +// if ix_count != appearances[ix as usize] { +// // not last appearance -> kept index contributes to new size +// new_legs.push((ix, ix_count)); +// size *= sizes[ix as usize]; +// } +// } +// let new_score = iscore + jscore + size; +// (new_legs, new_score) +// } + +// fn compute_con_cost_combo( +// temp_legs: Legs, +// appearances: &Vec, +// sizes: &Vec, +// iscore: &Score, +// jscore: &Score, +// factor: Score, +// ) -> (Legs, Score) { +// // remove indices that have reached final appearance +// // and compute cost and size of local contraction +// let mut new_legs: Legs = Legs::new(); +// let mut size: Score = 1; +// let mut cost: Score = 1; +// for (ix, ix_count) in temp_legs.into_iter() { +// // all involved indices contribute to the cost +// let d = sizes[ix as usize]; +// cost *= d; +// if ix_count != appearances[ix as usize] { +// // not last appearance -> kept index contributes to new size +// new_legs.push((ix, ix_count)); +// size *= d; +// } +// } +// // the score just for this contraction +// let new_local_score = cost + factor * size; + +// // the total score including history +// let new_score = iscore + jscore + new_local_score; + +// (new_legs, new_score) +// } + +// fn convert_bitpath_to_ssapath(bitpath: &BitPath, nterms: usize) -> SSAPath { +// let mut subgraph_to_ssa = Dict::default(); +// let mut ssa = 0; +// let mut ssa_path = Vec::new(); +// // create ssa leaves +// for i in 0..nterms { +// subgraph_to_ssa.insert(single_el_bitset(i, nterms), ssa); +// ssa += 1; +// } +// // process the path, creating parent ssa ids as we go +// for (isubgraph, jsubgraph) in bitpath.into_iter() { +// ssa_path.push((subgraph_to_ssa[isubgraph], subgraph_to_ssa[jsubgraph])); +// subgraph_to_ssa.insert(isubgraph.union(&jsubgraph).collect(), ssa); +// ssa += 1; +// } +// ssa_path +// } + +// #[pyfunction] +// #[pyo3()] +// fn optimal( +// inputs: Vec>, +// output: Vec, +// size_dict: Dict, +// minimize: Option, +// factor: Option, +// cost_cap: Option, +// ) -> SSAPath { +// let minimize = minimize.unwrap_or("flops".to_string()); +// let factor = factor.unwrap_or(64); +// let compute_cost = match minimize.as_str() { +// "flops" => compute_con_cost_flops, +// "size" => compute_con_cost_size, +// "write" => compute_con_cost_write, +// "combo" => compute_con_cost_combo, +// _ => panic!( +// "minimize must be one of 'flops', 'size', 'write', or 'combo', got {}", +// minimize +// ), +// }; + +// let nterms = inputs.len(); +// let mut indmap: Dict = Dict::default(); +// let mut sizes: Vec = Vec::with_capacity(size_dict.len()); +// let mut appearances: Vec = Vec::with_capacity(size_dict.len()); +// let mut c: Ix = 0; +// // storage for each possible contraction to reach subgraph of size k +// let mut contractions: Vec> = vec![Dict::default(); nterms + 1]; +// // intermediate storage for the entries we are expanding +// let mut contractions_m_temp: Vec<(Subgraph, SubContraction)> = Vec::new(); +// // need to keep these separately +// let mut best_scores: Dict = Dict::default(); + +// // map the string indices to integers, forming the int input terms as well +// for (j, term) in inputs.into_iter().enumerate() { +// let mut legs: Legs = Vec::new(); +// for ind in term { +// match indmap.get(&ind) { +// Some(&cex) => { +// // index already present +// appearances[cex as usize] += 1; +// legs.push((cex, 1)); +// } +// None => { +// // index not parsed yet +// indmap.insert(ind.clone(), c); +// sizes.push(size_dict[&ind]); +// appearances.push(1); +// legs.push((c, 1)); +// c += 1; +// } +// }; +// } +// legs.sort(); +// contractions[1].insert(single_el_bitset(j, nterms), (legs, 0, Vec::new())); +// } +// // parse the output -> just needed for appearances sake +// output.into_iter().for_each(|ind| { +// appearances[indmap[&ind] as usize] += 1; +// }); + +// // let mut inds_to_remove: Vec = Vec::new(); +// let mut ip: usize; +// let mut jp: usize; +// let mut outer: bool; + +// let mut cost_cap = cost_cap.unwrap_or(1); +// while contractions[nterms].len() == 0 { +// // try building subgraphs of size m +// for m in 2..=nterms { +// // out of bipartitions of size (k, m - k) +// for k in 1..=m / 2 { +// for (isubgraph, (ilegs, iscore, ipath)) in contractions[k].iter() { +// for (jsubgraph, (jlegs, jscore, jpath)) in contractions[m - k].iter() { +// // filter invalid combinations first +// if !isubgraph.is_disjoint(&jsubgraph) || { +// (k == m - k) && isubgraph.gt(&jsubgraph) +// } { +// // subgraphs overlap -> not valid, or +// // equal subgraph size -> only process sorted pairs +// continue; +// } + +// let mut temp_legs: Legs = Vec::new(); +// ip = 0; +// jp = 0; +// outer = true; +// while ip < ilegs.len() && jp < jlegs.len() { +// if ilegs[ip].0 < jlegs[jp].0 { +// // index only appears in ilegs +// temp_legs.push(ilegs[ip]); +// ip += 1; +// } else if ilegs[ip].0 > jlegs[jp].0 { +// // index only appears in jlegs +// temp_legs.push(jlegs[jp]); +// jp += 1; +// } else { +// // index appears in both +// temp_legs.push((ilegs[ip].0, ilegs[ip].1 + jlegs[jp].1)); +// ip += 1; +// jp += 1; +// outer = false; +// } +// } +// if outer { +// // no shared indices -> outer product +// continue; +// } +// // add any remaining indices +// temp_legs.extend(ilegs[ip..].iter().chain(jlegs[jp..].iter())); + +// // compute candidate contraction result and score +// let (new_legs, new_score) = +// compute_cost(temp_legs, &appearances, &sizes, iscore, jscore, factor); + +// if new_score > cost_cap { +// // contraction not allowed yet due to cost +// continue; +// } + +// // check candidate against current best subgraph path +// let new_subgraph: Subgraph = isubgraph.union(&jsubgraph).collect(); + +// // because we have to do a delayed update of +// // contractions[m] for borrowing reasons, we check +// // against a non-delayed score lookup so we don't +// // overwrite best scores within the same iteration +// let new_best = match best_scores.get(&new_subgraph) { +// Some(current_score) => new_score < *current_score, +// None => true, +// }; +// if new_best { +// best_scores.insert(new_subgraph.clone(), new_score); +// // only need the path if updating +// let mut new_path = ipath.clone(); +// new_path.extend(jpath.clone()); +// new_path.push((isubgraph.clone(), jsubgraph.clone())); +// contractions_m_temp +// .push((new_subgraph, (new_legs, new_score, new_path))); +// } +// } +// } +// // move new contractions from temp into the main storage, there +// // might be contractions for the same subgraph in this, but +// // because we check eagerly best_scores above, later entries +// // are guaranteed to be better +// contractions_m_temp.drain(..).for_each(|(k, v)| { +// contractions[m].insert(k, v); +// }); +// } +// } +// cost_cap *= 2; +// } +// // can only ever be a single entry in contractions[nterms] -> the best one +// let (_, _, best_path) = contractions[nterms].values().next().unwrap(); + +// // need to convert the path to 'ssa' ids rather than bitset +// return convert_bitpath_to_ssapath(best_path, nterms); +// } + +/// A Python module implemented in Rust. +#[pymodule] +fn cotengrust(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(test_simplify, m)?)?; + m.add_function(wrap_pyfunction!(test_subgraphs, m)?)?; + m.add_function(wrap_pyfunction!(test_greedy, m)?)?; + Ok(()) +}