/*
 *  Copyright 2008-2013 NVIDIA Corporation
 *  Copyright 2013 Filipe RNC Maia
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/complex.h>
#include <cfloat>
#include <cmath>
#include <thrust/detail/complex/c99math.h>

namespace thrust
{

  /* --- Binary Arithmetic Operators --- */

// Adds two complex numbers component by component.
// Returns a complex whose real part is lhs.real()+rhs.real() and whose
// imaginary part is lhs.imag()+rhs.imag().
template<typename ValueType>
__host__ __device__
inline complex<ValueType> operator+(const complex<ValueType>& lhs,
                                    const complex<ValueType>& rhs)
{
  const ValueType re_sum = lhs.real() + rhs.real();
  const ValueType im_sum = lhs.imag() + rhs.imag();
  return complex<ValueType>(re_sum, im_sum);
}

// Overload of the complex + complex sum for volatile-qualified operands.
// The real parts and the imaginary parts are added separately and the
// result is returned as a plain (non-volatile) complex value.
template<typename ValueType>
__host__ __device__
inline complex<ValueType> operator+(const volatile complex<ValueType>& lhs,
                                    const volatile complex<ValueType>& rhs)
{
  const ValueType re_sum = lhs.real() + rhs.real();
  const ValueType im_sum = lhs.imag() + rhs.imag();
  return complex<ValueType>(re_sum, im_sum);
}

// Adds a scalar to a complex number: only the real part is affected,
// the imaginary part of lhs is carried over unchanged.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator+(const complex<ValueType>& lhs, const ValueType & rhs)
{
  const ValueType re_sum = lhs.real() + rhs;
  return complex<ValueType>(re_sum, lhs.imag());
}

// Adds a complex number to a scalar (scalar on the left-hand side).
// The scalar is added to the real part of rhs; the imaginary part of rhs
// is returned as is.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator+(const ValueType& lhs, const complex<ValueType>& rhs)
{
  const ValueType re_sum = rhs.real() + lhs;
  return complex<ValueType>(re_sum, rhs.imag());
}

// Subtracts rhs from lhs component by component.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator-(const complex<ValueType>& lhs, const complex<ValueType>& rhs)
{
  const ValueType re_diff = lhs.real() - rhs.real();
  const ValueType im_diff = lhs.imag() - rhs.imag();
  return complex<ValueType>(re_diff, im_diff);
}

// Subtracts a scalar from a complex number: the scalar is removed from the
// real part, the imaginary part of lhs is left untouched.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator-(const complex<ValueType>& lhs, const ValueType & rhs)
{
  const ValueType re_diff = lhs.real() - rhs;
  return complex<ValueType>(re_diff, lhs.imag());
}

// Subtracts a complex number from a scalar (scalar on the left-hand side).
// The real part is lhs - rhs.real(); the imaginary part is the negation of
// rhs.imag().
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator-(const ValueType& lhs, const complex<ValueType>& rhs)
{
  const ValueType re_diff = lhs - rhs.real();
  const ValueType im_neg  = -rhs.imag();
  return complex<ValueType>(re_diff, im_neg);
}

// Multiplies two complex numbers using the textbook formula
//   (a + bi)(c + di) = (ac - bd) + (ad + bc)i
// with no special-casing of any operand values.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator*(const complex<ValueType>& lhs,
                                    const complex<ValueType>& rhs)
{
  const ValueType re_prod = lhs.real()*rhs.real()-lhs.imag()*rhs.imag();
  const ValueType im_prod = lhs.real()*rhs.imag()+lhs.imag()*rhs.real();
  return complex<ValueType>(re_prod, im_prod);
}

// Scales a complex number by a scalar: both the real and the imaginary
// part of lhs are multiplied by rhs.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator*(const complex<ValueType>& lhs, const ValueType & rhs)
{
  const ValueType re_scaled = lhs.real() * rhs;
  const ValueType im_scaled = lhs.imag() * rhs;
  return complex<ValueType>(re_scaled, im_scaled);
}

// Scales a complex number by a scalar written on the left-hand side:
// both components of rhs are multiplied by lhs.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator*(const ValueType& lhs, const complex<ValueType>& rhs)
{
  const ValueType re_scaled = rhs.real() * lhs;
  const ValueType im_scaled = rhs.imag() * lhs;
  return complex<ValueType>(re_scaled, im_scaled);
}


// Complex division lhs / rhs.
//
// Both operands are first multiplied by 1 / (|rhs.real()| + |rhs.imag()|).
// The quotient is then formed as (scaled lhs) * conj(scaled rhs) divided by
// the squared magnitude of the scaled rhs.
// NOTE(review): the pre-scaling is presumably there to keep the intermediate
// products within range -- confirm. No explicit handling of a zero or
// non-finite rhs is visible in this function.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator/(const complex<ValueType>& lhs, const complex<ValueType>& rhs){
  // Scale factor: sum of the absolute values of the divisor's components.
  ValueType s = std::abs(rhs.real()) + std::abs(rhs.imag());
  ValueType oos = ValueType(1.0) / s;
  // Scaled numerator (ars, ais) and scaled denominator (brs, bis).
  ValueType ars = lhs.real() * oos;
  ValueType ais = lhs.imag() * oos;
  ValueType brs = rhs.real() * oos;
  ValueType bis = rhs.imag() * oos;
  // s and oos are reused: squared magnitude of the scaled divisor and its
  // reciprocal.
  s = (brs * brs) + (bis * bis);
  oos = ValueType(1.0) / s;
  // Numerator times the conjugate of the denominator, divided by s.
  complex<ValueType> quot(((ars * brs) + (ais * bis)) * oos,
			 ((ais * brs) - (ars * bis)) * oos);
  return quot;
}

// Divides a complex number by a scalar: each component of lhs is divided
// by rhs independently.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator/(const complex<ValueType>& lhs, const ValueType & rhs)
{
  const ValueType re_quot = lhs.real() / rhs;
  const ValueType im_quot = lhs.imag() / rhs;
  return complex<ValueType>(re_quot, im_quot);
}

// Divides a scalar by a complex number. The scalar is first converted to a
// complex value and the complex / complex division is then applied.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator/(const ValueType& lhs, const complex<ValueType>& rhs)
{
  const complex<ValueType> numerator(lhs);
  return numerator / rhs;
}



/* --- Unary Arithmetic Operators --- */

// Unary plus: returns a copy of its operand.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator+(const complex<ValueType>& rhs)
{
  return complex<ValueType>(rhs);
}

// Unary minus: the negation is obtained by multiplying the operand by the
// scalar -ValueType(1) through the complex * scalar operator.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> operator-(const complex<ValueType>& rhs)
{
  const complex<ValueType> negated = rhs * -ValueType(1);
  return negated;
}


/* --- Other Basic Arithmetic Functions --- */

// std::hypot is only available from C++11 onward, so the C interface (hypot/hypotf) is used instead.
// Magnitude of a complex number, computed by calling hypot on the real and
// imaginary parts.
template <typename ValueType>
__host__ __device__
inline ValueType abs(const complex<ValueType>& z)
{
  const ValueType re = z.real();
  const ValueType im = z.imag();
  return hypot(re, im);
}

namespace detail{
namespace complex{

// Single-precision magnitude: forwards the two components to hypotf.
__host__ __device__
inline float abs(const thrust::complex<float>& z)
{
  const float re = z.real();
  const float im = z.imag();
  return hypotf(re, im);
}

// Double-precision magnitude: forwards the two components to hypot.
__host__ __device__
inline double abs(const thrust::complex<double>& z)
{
  const double re = z.real();
  const double im = z.imag();
  return hypot(re, im);
}

} // namespace complex
} // namespace detail

// Specialization of abs for complex<float>; delegates to the
// single-precision helper in detail::complex.
template <>
__host__ __device__
inline float abs(const complex<float>& z)
{
  const float magnitude = detail::complex::abs(z);
  return magnitude;
}
// Specialization of abs for complex<double>; delegates to the
// double-precision helper in detail::complex.
template<>
__host__ __device__
inline double abs(const complex<double>& z)
{
  const double magnitude = detail::complex::abs(z);
  return magnitude;
}


// Phase angle of a complex number, obtained as
// atan2(imaginary part, real part).
template <typename ValueType>
__host__ __device__
inline ValueType arg(const complex<ValueType>& z)
{
  const ValueType im = z.imag();
  const ValueType re = z.real();
  return std::atan2(im, re);
}

// Complex conjugate: same real part, imaginary part with the sign flipped.
template <typename ValueType>
__host__ __device__
inline complex<ValueType> conj(const complex<ValueType>& z)
{
  const ValueType im_flipped = -z.imag();
  return complex<ValueType>(z.real(), im_flipped);
}

// Squared magnitude of a complex number: real*real + imag*imag.
// This is the generic version; it applies no rescaling of the components.
template <typename ValueType>
__host__ __device__
inline ValueType norm(const complex<ValueType>& z)
{
  return z.real()*z.real() + z.imag()*z.imag();
}

// Squared magnitude for complex<float>.
// When the absolute values of both components are below sqrtf(FLT_MIN), the
// components are multiplied by 4 before squaring and the sum is divided by
// 16 afterwards; otherwise the plain real*real + imag*imag is returned.
// NOTE(review): the rescaling is presumably meant to limit precision loss
// when the squares would underflow -- confirm.
template <>
__host__ __device__
inline float norm(const complex<float>& z)
{
  const float threshold = ::sqrtf(FLT_MIN);
  const bool both_tiny = std::abs(z.real()) < threshold
                      && std::abs(z.imag()) < threshold;
  if(!both_tiny){
    return z.real()*z.real() + z.imag()*z.imag();
  }
  const float a = z.real()*4.0f;
  const float b = z.imag()*4.0f;
  return (a*a+b*b)/16.0f;
}

// Squared magnitude for complex<double>.
// When the absolute values of both components are below sqrt(DBL_MIN), the
// components are multiplied by 4 before squaring and the sum is divided by
// 16 afterwards; otherwise the plain real*real + imag*imag is returned.
// NOTE(review): the rescaling is presumably meant to limit precision loss
// when the squares would underflow -- confirm.
template <>
__host__ __device__
inline double norm(const complex<double>& z)
{
  const double threshold = ::sqrt(DBL_MIN);
  const bool both_tiny = std::abs(z.real()) < threshold
                      && std::abs(z.imag()) < threshold;
  if(!both_tiny){
    return z.real()*z.real() + z.imag()*z.imag();
  }
  const double a = z.real()*4.0;
  const double b = z.imag()*4.0;
  return (a*a+b*b)/16.0;
}

// Builds a complex number from polar coordinates.
//   m     - magnitude
//   theta - angle passed to std::cos / std::sin
// Returns (m*cos(theta), m*sin(theta)).
template <typename ValueType>
__host__ __device__
inline complex<ValueType> polar(const ValueType & m, const ValueType & theta)
{
  const ValueType re = m * std::cos(theta);
  const ValueType im = m * std::sin(theta);
  return complex<ValueType>(re, im);
}

}


