Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/numpy/core/einsumfunc.py : 6%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Implementation of optimized einsum.
4"""
5import itertools
6import operator
8from numpy.core.multiarray import c_einsum
9from numpy.core.numeric import asanyarray, tensordot
10from numpy.core.overrides import array_function_dispatch
12__all__ = ['einsum', 'einsum_path']
14einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
15einsum_symbols_set = set(einsum_symbols)
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Compute the number of FLOPS required for a contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One multiply per additional term in the contraction; an inner
    # product adds one more pass over the elements for the summation.
    ops_per_element = max(1, num_terms - 1)
    if inner:
        ops_per_element += 1

    return _compute_size_by_dict(idx_contraction, size_dictionary) * ops_per_element
56def _compute_size_by_dict(indices, idx_dict):
57 """
58 Computes the product of the elements in indices based on the dictionary
59 idx_dict.
61 Parameters
62 ----------
63 indices : iterable
64 Indices to base the product on.
65 idx_dict : dictionary
66 Dictionary of index sizes
68 Returns
69 -------
70 ret : int
71 The resulting product.
73 Examples
74 --------
75 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
76 90
78 """
79 ret = 1
80 for i in indices:
81 ret *= idx_dict[i]
82 return ret
85def _find_contraction(positions, input_sets, output_set):
86 """
87 Finds the contraction for a given set of input and output sets.
89 Parameters
90 ----------
91 positions : iterable
92 Integer positions of terms used in the contraction.
93 input_sets : list
94 List of sets that represent the lhs side of the einsum subscript
95 output_set : set
96 Set that represents the rhs side of the overall einsum subscript
98 Returns
99 -------
100 new_result : set
101 The indices of the resulting contraction
102 remaining : list
103 List of sets that have not been contracted, the new set is appended to
104 the end of this list
105 idx_removed : set
106 Indices removed from the entire contraction
107 idx_contraction : set
108 The indices used in the current contraction
110 Examples
111 --------
113 # A simple dot product test case
114 >>> pos = (0, 1)
115 >>> isets = [set('ab'), set('bc')]
116 >>> oset = set('ac')
117 >>> _find_contraction(pos, isets, oset)
118 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
120 # A more complex case with additional terms in the contraction
121 >>> pos = (0, 2)
122 >>> isets = [set('abd'), set('ac'), set('bdc')]
123 >>> oset = set('ac')
124 >>> _find_contraction(pos, isets, oset)
125 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
126 """
128 idx_contract = set()
129 idx_remain = output_set.copy()
130 remaining = []
131 for ind, value in enumerate(input_sets):
132 if ind in positions:
133 idx_contract |= value
134 else:
135 remaining.append(value)
136 idx_remain |= value
138 new_result = idx_remain & idx_contract
139 idx_removed = (idx_contract - new_result)
140 remaining.append(new_result)
142 return (new_result, remaining, idx_removed, idx_contract)
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Compute all possible pair contractions, sieve the results based
    on ``memory_limit`` and return the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each entry is (total_cost_so_far, contraction_positions, operand_sets).
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Expand every surviving partial path by every possible pair.
        for cost, positions, remaining in full_results:
            for con in itertools.combinations(
                    range(len(input_sets) - iteration), 2):

                new_result, new_input_sets, idx_removed, idx_contract = \
                    _find_contraction(con, remaining, output_set)

                # Sieve the results based on memory_limit
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(con), idx_dict)
                iter_results.append(
                    (total_cost, positions + [con], new_input_sets))

        # Update combinatorial list; if every candidate blew the memory
        # limit, take the best partial path and contract the rest at once.
        if iter_results:
            full_results = iter_results
        else:
            path = min(full_results, key=operator.itemgetter(0))[1]
            path += [tuple(range(len(input_sets) - iteration))]
            return path

    # If we have not found anything return single einsum contraction
    if len(full_results) == 0:
        return [tuple(range(len(input_sets)))]

    path = min(full_results, key=operator.itemgetter(0))[1]
    return path
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction is performed.

    """
    # Find the contraction
    idx_result, new_input_sets, idx_removed, idx_contract = \
        _find_contraction(positions, input_sets, output_set)

    # Sieve the results based on memory_limit
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Build sort tuple: primarily by how much memory the contraction frees,
    # secondarily by its flop cost.
    freed_size = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                     for p in positions) - new_size

    # NB: this used to be just the size of any removed indices, i.e.:
    # _compute_size_by_dict(idx_removed, idx_dict)
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    sort = (-freed_size, cost)

    # Sieve based on total cost as well
    if (path_cost + cost) > naive_cost:
        return None

    # Add contraction to possible choices
    return [sort, positions, new_input_sets]
273def _update_other_results(results, best):
274 """Update the positions and provisional input_sets of ``results`` based on
275 performing the contraction result ``best``. Remove any involving the tensors
276 contracted.
278 Parameters
279 ----------
280 results : list
281 List of contraction results produced by ``_parse_possible_contraction``.
282 best : list
283 The best contraction of ``results`` i.e. the one that will be performed.
285 Returns
286 -------
287 mod_results : list
288 The list of modified results, updated with outcome of ``best`` contraction.
289 """
291 best_con = best[1]
292 bx, by = best_con
293 mod_results = []
295 for cost, (x, y), con_sets in results:
297 # Ignore results involving tensors just contracted
298 if x in best_con or y in best_con:
299 continue
301 # Update the input_sets
302 del con_sets[by - int(by > x) - int(by > y)]
303 del con_sets[bx - int(bx > x) - int(bx > y)]
304 con_sets.insert(-1, best[2][-1])
306 # Update the position indices
307 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
308 mod_results.append((cost, mod_con, con_sets))
310 return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Find the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost to sieve candidates against.
    _, _, idx_removed, idx_contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set)
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on the first step, only previously found
        # pairs on subsequent steps.
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict, memory_limit,
                path_cost, naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have an inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict, memory_limit,
                    path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum-like behavior.
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index
        best = min(known_contractions, key=operator.itemgetter(0))

        # Now propagate as many unused contractions as possible to the next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor;
        # all other contractions have been accounted for.
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path
413def _can_dot(inputs, result, idx_removed):
414 """
415 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
417 Parameters
418 ----------
419 inputs : list of str
420 Specifies the subscripts for summation.
421 result : str
422 Resulting summation.
423 idx_removed : set
424 Indices that are removed in the summation
427 Returns
428 -------
429 type : bool
430 Returns true if BLAS should and can be used, else False
432 Notes
433 -----
434 If the operations is BLAS level 1 or 2 and is not already aligned
435 we default back to einsum as the memory movement to copy is more
436 costly than the operation itself.
439 Examples
440 --------
442 # Standard GEMM operation
443 >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
444 True
446 # Can use the standard BLAS, but requires odd data movement
447 >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
448 False
450 # DDOT where the memory is not aligned
451 >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
452 False
454 """
456 # All `dot` calls remove indices
457 if len(idx_removed) == 0:
458 return False
460 # BLAS can only handle two operands
461 if len(inputs) != 2:
462 return False
464 input_left, input_right = inputs
466 for c in set(input_left + input_right):
467 # can't deal with repeated indices on same input or more than 2 total
468 nl, nr = input_left.count(c), input_right.count(c)
469 if (nl > 1) or (nr > 1) or (nl + nr > 2):
470 return False
472 # can't do implicit summation or dimension collapse e.g.
473 # "ab,bc->c" (implicitly sum over 'a')
474 # "ab,ca->ca" (take diagonal of 'a')
475 if nl + nr - 1 == int(c in result):
476 return False
478 # Build a few temporaries
479 set_left = set(input_left)
480 set_right = set(input_right)
481 keep_left = set_left - idx_removed
482 keep_right = set_right - idx_removed
483 rs = len(idx_removed)
485 # At this point we are a DOT, GEMV, or GEMM operation
487 # Handle inner products
489 # DDOT with aligned data
490 if input_left == input_right:
491 return True
493 # DDOT without aligned data (better to use einsum)
494 if set_left == set_right:
495 return False
497 # Handle the 4 possible (aligned) GEMV or GEMM cases
499 # GEMM or GEMV no transpose
500 if input_left[-rs:] == input_right[:rs]:
501 return True
503 # GEMM or GEMV transpose both
504 if input_left[:rs] == input_right[-rs:]:
505 return True
507 # GEMM or GEMV transpose right
508 if input_left[-rs:] == input_right[-rs:]:
509 return True
511 # GEMM or GEMV transpose left
512 if input_left[:rs] == input_right[:rs]:
513 return True
515 # Einsum is faster than GEMV if we have to copy data
516 if not keep_left or not keep_right:
517 return False
519 # We are a matrix-matrix product, but we need to copy data
520 return True
def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # String mode: ("ij,jk->ik", a, b)
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved mode: (a, [0, 1], b, [1, 2], [0, 2]?) — convert the
        # integer/Ellipsis axis labels into an equivalent subscript string.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # An odd trailing element, if any, is the output axis list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]

        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each "..." with concrete, unused symbols.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: every index that appears
            # exactly once survives into the implicit output.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
696def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
697 # NOTE: technically, we should only dispatch on array-like arguments, not
698 # subscripts (given as strings). But separating operands into
699 # arrays/subscripts is a little tricky/slow (given einsum's two supported
700 # signatures), so as a practical shortcut we dispatch on everything.
701 # Strings will be ignored for dispatching since they don't define
702 # __array_function__.
703 return operands
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il # may vary
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.600e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.857
      Largest intermediate:  4.000e+00 elements
    -------------------------------------------------------------------------
    scaling                  current                                remaining
    -------------------------------------------------------------------------
       3                   kl,jk->jl                                ij,jl->il
       3                   jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  8.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  1000.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5               bcde,fb->cdef                         gc,hd,cdef->efgh
       5               cdef,gc->defg                            hd,defg->efgh
       5               defg,hd->efgh                               efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        pass

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for x in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (input_subscripts[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict.keys():
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dimension_dict[char], dim))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)

    # Compute the path
    if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type[0] == 'einsum_path':
        path = path_type[1:]
    else:
        # BUG FIX: was ``raise KeyError("Path name %s not found", path_type)``
        # which passed a (message, value) tuple to KeyError instead of a
        # formatted message.
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Sort the intermediate's indices by dimension size for better
            # downstream alignment.
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = "  Complete contraction:  %s\n" % overall_contraction
    path_print += "         Naive scaling:  %d\n" % len(indices)
    path_print += "     Optimized scaling:  %d\n" % max(scale_list)
    path_print += "      Naive FLOP count:  %.3e\n" % naive_cost
    path_print += "  Optimized FLOP count:  %.3e\n" % opt_cost
    path_print += "   Theoretical speedup:  %3.3f\n" % speedup
    path_print += "  Largest intermediate:  %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d    %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)
989def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
990 # Arguably we dispatch on more arguments that we really should; see note in
991 # _einsum_path_dispatcher for why.
992 yield from operands
993 yield out
# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically. This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels. This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis. Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified: remember it now, because the
    # 'out' kwarg must only be forwarded on the final contraction below.
    specified_out = out is not None

    # If no optimization, run pure einsum (single call into the C core).
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for (k, v) in kwargs.items() if
                      k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)


    # Build the contraction list and operand.  With einsum_call=True,
    # einsum_path returns (operands, contraction_list) instead of the
    # printable path (see the einsum_call_arg branch in einsum_path).
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Start contraction loop: perform one pairwise (or n-ary) contraction
    # per entry, feeding each intermediate result back into `operands`.
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # Remove the consumed operands; `inds` comes from einsum_path and is
        # presumably ordered so successive pops hit the intended operands.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?  Only on the last contraction.
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Labels surviving the contraction, in tensordot's output order:
            # left operand's free labels followed by the right's.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over (axis positions of each removed
            # label in the left and right operands).
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))

            # Build a new view if needed: permute/reshape tensordot's result
            # into the requested index order (and/or write into `out`).
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(tensor_result + '->' + results_index, new_view, **kwargs)

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        return operands[0]