#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains code generation tools for the ufc::finite_element class.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2008-08-13
# Last changed: 2009-03-19

from itertools import izip
import ufl
import swiginac

from sfc.codegeneration.codeformatting import indent, CodeFormatter, gen_switch, gen_token_assignments, gen_const_token_definitions
from sfc.symbolic_utils import symbol, symbols, TempSymbolContext
from sfc.geometry import gen_geometry_code
from sfc.common import sfc_warning, sfc_assert, sfc_error, sfc_info

def cse(*_ignored):
    """Placeholder for common subexpression elimination on token lists.

    Not implemented yet: every call is reported via ``sfc_error``.
    Callers in this module pass ``(tokens, temp_symbol)`` and unpack a
    ``(tokens, repmap)`` pair from the result, so enabling the
    ``optimize_basis`` option currently ends up here.
    """
    sfc_error("FIXME")

class FiniteElementCG:
    """Code generator for ufc::finite_element implementations.

    Every ``gen_<name>`` method returns the C++ source text for the body of
    the ufc::finite_element member function ``<name>``.  All element data is
    read from the element representation given to the constructor.
    """
    def __init__(self, elementrep):
        self.rep = elementrep
        self.classname  = elementrep.finite_element_classname
        self.signature  = repr(self.rep.ufl_element)
        self.options = self.rep.options.code.finite_element
        # Highest derivative order emitted by evaluate_basis_derivatives(_all).
        self.DN_order = self.options.evaluate_basis_derivatives_order

    def hincludes(self):
        """Return extra header-file includes needed by the generated class (none)."""
        l = []
        return l

    def cincludes(self):
        """Return extra source-file includes needed by the generated class (none)."""
        l = []
        return l

    def generate_code_dict(self):
        """Return a dict mapping template keys to the generated code components.

        All components except ``classname`` are passed through ``indent``.
        """
        code_dict = {
            "classname"                      : self.classname,
            "constructor"                    : indent(self.gen_constructor()),
            "constructor_arguments"          : indent(self.gen_constructor_arguments()),
            "initializer_list"               : indent(self.gen_initializer_list()),
            "destructor"                     : indent(self.gen_destructor()),
            "create"                         : indent(self.gen_create()),
            "signature"                      : indent(self.gen_signature()),
            "cell_shape"                     : indent(self.gen_cell_shape()),
            "geometric_dimension"            : indent(self.gen_geometric_dimension()),
            "topological_dimension"          : indent(self.gen_topological_dimension()),
            "space_dimension"                : indent(self.gen_space_dimension()),
            "value_rank"                     : indent(self.gen_value_rank()),
            "value_dimension"                : indent(self.gen_value_dimension()),
            "map_from_reference_cell"        : indent(self.gen_map_from_reference_cell()),
            "map_to_reference_cell"          : indent(self.gen_map_to_reference_cell()),
            "evaluate_basis"                 : indent(self.gen_evaluate_basis()),
            "evaluate_basis_all"             : indent(self.gen_evaluate_basis_all()),
            "evaluate_basis_derivatives"     : indent(self.gen_evaluate_basis_derivatives()),
            "evaluate_basis_derivatives_all" : indent(self.gen_evaluate_basis_derivatives_all()),
            "evaluate_dof"                   : indent(self.gen_evaluate_dof()),
            "evaluate_dofs"                  : indent(self.gen_evaluate_dofs()),
            "interpolate_vertex_values"      : indent(self.gen_interpolate_vertex_values()),
            "num_sub_elements"               : indent(self.gen_num_sub_elements()),
            "create_sub_element"             : indent(self.gen_create_sub_element()),
            "members"                        : indent(self.gen_members()),
        }
        return code_dict

    def generate_support_code(self):
        """Return code to place outside the class (none)."""
        return ""

    def gen_constructor(self):
        return ""

    def gen_constructor_arguments(self):
        return ""

    def gen_initializer_list(self):
        return ""

    def gen_destructor(self):
        return ""

    def gen_create(self):
        """finite_element* create() const"""
        code = "return new %s();" % self.classname
        return code

    def gen_signature(self):
        """const char* signature() const"""
        return 'return "%s";' % self.signature

    def gen_cell_shape(self):
        """shape cell_shape() const"""
        return "return ufc::%s;" % self.rep.cell.shape

    def gen_geometric_dimension(self):
        """unsigned int geometric_dimension() const"""
        return "return %d;" % self.rep.cell.nsd

    def gen_topological_dimension(self):
        """unsigned int topological_dimension() const"""
        # NOTE(review): assumes topological == geometric dimension (no
        # embedded manifolds); confirm if such cells are ever supported.
        return "return %d;" % self.rep.cell.nsd

    def gen_space_dimension(self):
        """unsigned int space_dimension() const"""
        return "return %d;" % self.rep.local_dimension

    def gen_value_rank(self):
        """unsigned int value_rank() const"""
        return "return %d;" % self.rep.value_rank

    def gen_value_dimension(self):
        """unsigned int value_dimension(unsigned int i) const"""
        if self.rep.value_rank == 0:
            code = 'throw std::runtime_error("Rank 0 value has no dimension.");\n'
        else:
            # One case per axis of the value shape; invalid i falls through to the throw.
            cases = [(i, "return %d;" % d) for (i, d) in enumerate(self.rep.value_shape)]
            code = gen_switch("i", cases)
            code += 'throw std::runtime_error("Invalid dimension for rank %d value.");\n' % self.rep.value_rank
        return code

    def gen_map_from_reference_cell(self):
        return 'throw std::runtime_error("Not implemented.");' # FIXME

    def gen_map_to_reference_cell(self):
        return 'throw std::runtime_error("Not implemented.");' # FIXME

    def _split_temporaries(self, tokens, temp_symbol):
        """Split ``(s, e)`` tokens into temporaries ``(ts, e)`` and outputs ``(s, ts)``.

        A fresh symbol ``ts`` is drawn from ``temp_symbol`` for each token, so
        the expressions can be optimized independently of the output writes.
        Returns the pair ``(temporaries, outputs)``.
        """
        temporaries = []
        outputs = []
        for s, e in tokens:
            ts = temp_symbol()
            temporaries.append((ts, e))
            outputs.append((s, ts))
        return temporaries, outputs

    def gen_evaluate_basis(self):
        """void evaluate_basis(unsigned int i,
                                double* values,
                                const double* coordinates,
                                const cell& c) const
        """
        if not self.options.enable_evaluate_basis:
            return 'throw std::runtime_error("evaluate_basis not implemented.");'

        nsd = self.rep.cell.nsd
        nbf = self.rep.local_dimension
        value_size = self.rep.value_size

        if self.rep.ufl_element.family() == "Real":
            # Generated code sets component i to 1.0 and (for vector values) all others to 0.
            code = []
            if value_size > 1:
                code += ["memset(values, 0, sizeof(double)*%d);" % value_size]
            code += ['values[i] = 1.0;']
            return '\n'.join(code)

        # symbols for output values
        val_sym = symbols(["values[%d]" % d for d in range(value_size)])

        code = CodeFormatter()

        code += gen_geometry_code(nsd, detG=False, GinvT=True)

        # Reference coordinates: xi = GinvT^T * (coordinates - p0)
        coordinates = swiginac.matrix(nsd, 1, symbols(["coordinates[%d]" % i for i in range(nsd)]))
        xi = self.rep.GinvT.transpose().mul(coordinates - self.rep.p0).evalm()

        for i in range(nsd):
            code += "const double %s = %s;" % (("x", "y", "z")[i], xi[i].printc())

        # Zero components may only be omitted when the array was zeroed first.
        skip_zeros = value_size > 1
        if skip_zeros:
            code += "memset(values, 0, sizeof(double)*%d);" % value_size
            code += ""

        if nbf > 1:
            code.begin_switch("i")

        # one case per basis function
        for i in range(nbf):
            values = [self.rep.basis_function(i, c) for c in self.rep.value_components]

            # "values[d] = expression" tokens for every component that must be written
            values_tokens = []
            for d in range(value_size):
                if not (skip_zeros and values[d].expand().is_zero()):
                    values_tokens.append((val_sym[d], values[d]))

            if self.options.optimize_basis:
                temp_symbol = TempSymbolContext()
                values_tokens1, values_tokens2 = self._split_temporaries(values_tokens, temp_symbol)
                values_tokens1 = cse(values_tokens1, temp_symbol)[0]
                values_code = gen_const_token_definitions(values_tokens1) + "\n"
                values_code += gen_token_assignments(values_tokens2)
            else:
                values_code = gen_token_assignments(values_tokens)

            if nbf > 1:
                code.begin_case(i, braces=True)
                code.new_line(indent(values_code))
                code.end_case()
            else:
                code += values_code

        if nbf > 1:
            code.end_switch()

        return str(code)

    def gen_evaluate_basis_all(self): # TODO: implement optimized version of this
        """/// Evaluate all basis functions at given point in cell
           virtual void evaluate_basis_all(double* values,
                                           const double* coordinates,
                                           const cell& c) const
        """
        code = CodeFormatter()
        code += "for(unsigned i = 0; i < %d; i++)" % self.rep.local_dimension
        code.begin_block()
        code += "evaluate_basis(i, values+i*%d, coordinates, c);" % self.rep.value_size
        code.end_block()
        return str(code)

    def gen_evaluate_basis_derivatives(self):
        """/// Evaluate order n derivatives of basis function i at given point in cell
           void evaluate_basis_derivatives(unsigned int i,
                                            unsigned int n,
                                            double* values,
                                            const double* coordinates,
                                            const ufc::cell& c) const

        The generated code fills value_size * nsd**n entries of ``values``,
        stored as values[j * value_size + k] for derivative direction tuple j
        and value component k.
        NOTE(review): confirm this derivative-major layout matches the UFC
        convention expected by callers.
        """
        if not self.options.enable_evaluate_basis_derivatives:
            return 'throw std::runtime_error("evaluate_basis_derivatives not implemented.");'

        nsd = self.rep.cell.nsd
        nbf = self.rep.local_dimension
        value_size = self.rep.value_size

        if self.rep.ufl_element.family() == "Real":
            # Every derivative of a constant vanishes.  Zero the whole
            # value_size * nsd**n output (the size evaluate_basis_derivatives_all
            # strides by), not only the first value_size entries.
            return "\n".join([
                "unsigned int num_derivatives = 1;",
                "for (unsigned int r = 0; r < n; r++)",
                "    num_derivatives *= %d;" % nsd,
                "memset(values, 0, sizeof(double)*%d*num_derivatives);" % value_size,
            ])

        sfc_assert(self.DN_order <= 2, "Don't support computing higher order derivatives (yet).")

        code = CodeFormatter()
        # Reject orders above the configured maximum up front.
        code.begin_if("n > %d" % self.DN_order)
        code += 'throw std::runtime_error("evaluate_basis_derivatives not implemented for the wanted derivative order.");'
        code.end_if()

        # define GinvT
        code += gen_geometry_code(nsd, detG=False, GinvT=True)

        # define x,y,z (spatial symbols)
        p      = symbols(["x", "y", "z"][:nsd])
        coords = symbols(["coordinates[%d]" % i for i in range(nsd)])
        code += gen_const_token_definitions(izip(p, coords))

        # switch on derivative order
        code.begin_switch("n")

        for order in range(1, self.DN_order + 1):
            code.begin_case(order, braces=True)

            # Direction tuples for derivatives of *this* order (expected to be
            # nsd**order tuples, matching num_derivatives).  Materialized once
            # because it is iterated for every basis function.
            all_directions = list(ufl.permutation.compute_permutations(order, nsd))

            # Zero large output arrays first so only nonzero entries need assignments.
            num_derivatives = nsd ** order
            do_memset = (value_size * num_derivatives) > 6
            if do_memset:
                code += "memset(values, 0, sizeof(double) * %d * %d);" % (value_size, num_derivatives)
                code += ""

            # switch on basis function number
            code.begin_switch("i")
            for ibf in range(nbf):
                code.begin_case(ibf, braces=True)

                DN_tokens = []
                for j, directions in enumerate(all_directions):
                    for k, c in enumerate(self.rep.value_components):
                        # derivative wrt directions of component c of basis function ibf
                        DN = self.rep.basis_function_derivative(ibf, c, directions)
                        # zeros may be skipped only when the memset already wrote them
                        if not (do_memset and DN.expand().is_zero()):
                            values_sym = symbol("values[%d * %d + %d]" % (j, value_size, k))
                            DN_tokens.append((values_sym, DN))

                if self.options.optimize_basis:
                    temp_symbol = TempSymbolContext()
                    DN_tokens1, DN_tokens2 = self._split_temporaries(DN_tokens, temp_symbol)
                    DN_tokens1 = cse(DN_tokens1, temp_symbol)[0]
                    code += gen_const_token_definitions(DN_tokens1)
                    code += gen_token_assignments(DN_tokens2)
                else:
                    code += gen_token_assignments(DN_tokens)
                code.end_case()
            code.end_switch()
            code.end_case()
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Derivatives of this order are not supported in evaluate_basis_derivatives.");'
        code.dedent()
        code.end_switch()
        return str(code)

    def gen_evaluate_basis_derivatives_all(self): # TODO: implement optimized version of this
        """/// Evaluate order n derivatives of all basis functions at given point in cell
           virtual void evaluate_basis_derivatives_all(unsigned int n,
                                                       double* values,
                                                       const double* coordinates,
                                                       const cell& c) const
        """
        code = CodeFormatter()
        code.begin_switch("n")
        for order in range(1, self.DN_order + 1):
            code.begin_case(order, braces=True)
            # each basis function fills value_size * nsd**order entries
            offset = self.rep.value_size * (self.rep.cell.nsd ** order)
            code += "for(unsigned i = 0; i < %d; i++)" % self.rep.local_dimension
            code.begin_block()
            # argument order follows the signature evaluate_basis_derivatives(i, n, ...)
            code += "evaluate_basis_derivatives(i, n, values+i*%d, coordinates, c);" % offset
            code.end_block()
            code.end_case()
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Derivatives of this order are not implemented in evaluate_basis_derivatives_all.");'
        code.dedent()
        code.end_switch()
        return str(code)

    def gen_evaluate_dof(self):
        """double evaluate_dof(unsigned int i,
                               const function& f,
                               const cell& c) const
           This implementation is general for all elements with point evaluation dofs.

           TODO: Implement support for normal and tangential component dofs.
           TODO: Implement for elements without point evaluation dofs.
        """
        nsd = self.rep.cell.nsd
        nbf = self.rep.local_dimension

        code = CodeFormatter()
        code.new_text(gen_geometry_code(nsd, detG=False))
        code += "double v[%d];" % self.rep.value_size
        code += "double x[%d];" % nsd

        # fill global coordinates of dof in x[]
        if nbf == 1:
            # skip the switch if only one basis function
            for k in range(nsd):
                code += "x[%d] = %s;" % (k, self.rep.dof_x[0][k].printc())
        else:
            code.begin_switch("i")
            for i in range(nbf):
                code.begin_case(i)
                for k in range(nsd):
                    code += "x[%d] = %s;" % (k, self.rep.dof_x[i][k].printc())
                code.end_case()
            code.end_switch()

        # Evaluate the function (this evaluates all value components!)
        code += "f.evaluate(v, x, c);"

        # dofs of the sub elements are numbered contiguously, so dof i belongs
        # to value component i / (nbf / value_size)
        if nbf == 1:
            code += "return v[0];"
        else:
            code += "return v[i / %d];" % (nbf // self.rep.value_size)

        return str(code)

    def gen_evaluate_dofs(self): # TODO: implement optimized version of this
        """/// Evaluate linear functionals for all dofs on the function f
           virtual void evaluate_dofs(double* values,
                                      const function& f,
                                      const cell& c) const
        """
        code = CodeFormatter()
        code += "for(unsigned i=0; i<%d; i++)" % self.rep.local_dimension
        code.begin_block()
        code += "values[i] = evaluate_dof(i, f, c);"
        code.end_block()
        return str(code)

    def gen_interpolate_vertex_values(self):
        """void interpolate_vertex_values(double* vertex_values,
                                           const double* dof_values,
                                           const cell& c) const

        Output entries are ordered vertex by vertex, with all value components
        of one vertex stored consecutively.
        """
        nbf = self.rep.local_dimension
        nsd = self.rep.cell.nsd
        nv  = self.rep.cell.num_entities[0]
        value_size = self.rep.value_size

        # symbols for input array entries
        dof_values_sym    = symbols("dof_values[%d]" % i    for i in xrange(nbf))
        vertex_values_sym = symbols("vertex_values[%d]" % i for i in xrange(nv*value_size))

        # the spatial symbols u is expressed in
        p = self.rep.p

        # u[c] = sum_j phi_j[c] * dof_values[j] for each value component c
        u = []
        for component in self.rep.value_components:
            u.append(sum(self.rep.basis_function(j, component)*dof_values_sym[j] for j in range(nbf)))

        # substitute each vertex's coordinates into every component
        vertex_values = []
        repmap = swiginac.exmap()
        for i in range(nv):
            vx = self.rep.polygon.vertex(i)
            for k in range(nsd):
                repmap[p[k]] = vx[k]
            for uval in u:
                vertex_values.append(uval.subs(repmap))

        code = gen_token_assignments(izip(vertex_values_sym, vertex_values))
        return code

    def gen_num_sub_elements(self):
        """unsigned int num_sub_elements() const"""
        return  "return %d;" % len(self.rep.sub_elements)

    def gen_create_sub_element(self):
        """finite_element* create_sub_element(unsigned int i) const"""
        if len(self.rep.sub_elements) > 1:
            code = CodeFormatter()
            code.begin_switch("i")
            for i, fe in enumerate(self.rep.sub_elements):
                code += "case %d: return new %s();" % (i, fe.finite_element_classname)
            code.end_switch()
            code += 'throw std::runtime_error("Invalid index in create_sub_element.");'
        else:
            code = "return new %s();" % self.classname # FIXME: Should we throw error here instead now?
        return str(code)

    def gen_members(self):
        """Return extra member declarations for the generated class (none)."""
        return ""

