@@ -460,7 +460,7 @@ def power(self, a, exponents):
460460 """
461461 raise NotImplementedError ()
462462
463- def norm (self , a ):
463+ def norm (self , a , axis = None , keepdims = False ):
464464 r"""
465465 Computes the matrix frobenius norm.
466466
@@ -680,7 +680,7 @@ def diag(self, a, k=0):
680680 """
681681 raise NotImplementedError ()
682682
683- def unique (self , a ):
683+ def unique (self , a , return_inverse = False ):
684684 r"""
685685 Finds unique elements of given tensor.
686686
@@ -1140,8 +1140,8 @@ def sqrt(self, a):
11401140 def power (self , a , exponents ):
11411141 return np .power (a , exponents )
11421142
1143- def norm (self , a ):
1144- return np .sqrt ( np . sum ( np . square ( a )) )
1143+ def norm (self , a , axis = None , keepdims = False ):
1144+ return np .linalg . norm ( a , axis = axis , keepdims = keepdims )
11451145
11461146 def any (self , a ):
11471147 return np .any (a )
@@ -1217,8 +1217,8 @@ def meshgrid(self, a, b):
12171217 def diag (self , a , k = 0 ):
12181218 return np .diag (a , k )
12191219
1220- def unique (self , a ):
1221- return np .unique (a )
1220+ def unique (self , a , return_inverse = False ):
1221+ return np .unique (a , return_inverse = return_inverse )
12221222
12231223 def logsumexp (self , a , axis = None ):
12241224 return special .logsumexp (a , axis = axis )
@@ -1514,8 +1514,8 @@ def sqrt(self, a):
15141514 def power (self , a , exponents ):
15151515 return jnp .power (a , exponents )
15161516
1517- def norm (self , a ):
1518- return jnp .sqrt ( jnp . sum ( jnp . square ( a )) )
1517+ def norm (self , a , axis = None , keepdims = False ):
1518+ return jnp .linalg . norm ( a , axis = axis , keepdims = keepdims )
15191519
15201520 def any (self , a ):
15211521 return jnp .any (a )
@@ -1588,8 +1588,8 @@ def meshgrid(self, a, b):
15881588 def diag (self , a , k = 0 ):
15891589 return jnp .diag (a , k )
15901590
1591- def unique (self , a ):
1592- return jnp .unique (a )
1591+ def unique (self , a , return_inverse = False ):
1592+ return jnp .unique (a , return_inverse = return_inverse )
15931593
15941594 def logsumexp (self , a , axis = None ):
15951595 return jspecial .logsumexp (a , axis = axis )
@@ -1934,8 +1934,8 @@ def sqrt(self, a):
19341934 def power (self , a , exponents ):
19351935 return torch .pow (a , exponents )
19361936
1937- def norm (self , a ):
1938- return torch .sqrt ( torch . sum ( torch . square ( a )) )
1937+ def norm (self , a , axis = None , keepdims = False ):
1938+ return torch .linalg . norm ( a . double (), dim = axis , keepdims = keepdims )
19391939
19401940 def any (self , a ):
19411941 return torch .any (a )
@@ -2039,8 +2039,8 @@ def meshgrid(self, a, b):
20392039 def diag (self , a , k = 0 ):
20402040 return torch .diag (a , diagonal = k )
20412041
2042- def unique (self , a ):
2043- return torch .unique (a )
2042+ def unique (self , a , return_inverse = False ):
2043+ return torch .unique (a , return_inverse = return_inverse )
20442044
20452045 def logsumexp (self , a , axis = None ):
20462046 if axis is not None :
@@ -2359,8 +2359,8 @@ def power(self, a, exponents):
23592359 def dot (self , a , b ):
23602360 return cp .dot (a , b )
23612361
2362- def norm (self , a ):
2363- return cp .sqrt ( cp . sum ( cp . square ( a )) )
2362+ def norm (self , a , axis = None , keepdims = False ):
2363+ return cp .linalg . norm ( a , axis = axis , keepdims = keepdims )
23642364
23652365 def any (self , a ):
23662366 return cp .any (a )
@@ -2436,8 +2436,8 @@ def meshgrid(self, a, b):
24362436 def diag (self , a , k = 0 ):
24372437 return cp .diag (a , k )
24382438
2439- def unique (self , a ):
2440- return cp .unique (a )
2439+ def unique (self , a , return_inverse = False ):
2440+ return cp .unique (a , return_inverse = return_inverse )
24412441
24422442 def logsumexp (self , a , axis = None ):
24432443 # Taken from
@@ -2770,8 +2770,8 @@ def sqrt(self, a):
27702770 def power (self , a , exponents ):
27712771 return tnp .power (a , exponents )
27722772
2773- def norm (self , a ):
2774- return tf .math .reduce_euclidean_norm (a )
2773+ def norm (self , a , axis = None , keepdims = False ):
2774+ return tf .math .reduce_euclidean_norm (a , axis = axis , keepdims = keepdims )
27752775
27762776 def any (self , a ):
27772777 return tnp .any (a )
@@ -2843,8 +2843,15 @@ def meshgrid(self, a, b):
28432843 def diag (self , a , k = 0 ):
28442844 return tnp .diag (a , k )
28452845
2846- def unique (self , a ):
2847- return tf .sort (tf .unique (tf .reshape (a , [- 1 ]))[0 ])
2846+ def unique (self , a , return_inverse = False ):
2847+ y , idx = tf .unique (tf .reshape (a , [- 1 ]))
2848+ sort_idx = tf .argsort (y )
2849+ y_prime = tf .gather (y , sort_idx )
2850+ if return_inverse :
2851+ inv_sort_idx = tf .math .invert_permutation (sort_idx )
2852+ return y_prime , tf .gather (inv_sort_idx , idx )
2853+ else :
2854+ return y_prime
28482855
28492856 def logsumexp (self , a , axis = None ):
28502857 return tf .math .reduce_logsumexp (a , axis = axis )
0 commit comments