/* -*-c++-*- OpenSceneGraph - Copyright (C) Cedric Pinson
 *
 * This application is open source and may be redistributed and/or modified
 * freely and without restriction, both in commercial and non commercial
 * applications, as long as this copyright notice is maintained.
 *
 * This application is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
*/

#ifndef GEOMETRY_SPLITTER_VISITOR
#define GEOMETRY_SPLITTER_VISITOR

#include <set>
#include <algorithm>

#include <osg/ref_ptr>
#include <osg/Geometry>
#include <osg/PrimitiveSet>
#include <osg/ValueObject>
#include <osgUtil/MeshOptimizers>

#include "GeometryArray"
#include "GeometryUniqueVisitor"
#include "glesUtil"


/**
 * Splits an indexed osg::Geometry into several geometries so that no
 * DrawElements index is greater than a given maximum.
 *
 * After split() returns, _geometryList holds the geometries that stand for the
 * input: the input geometry itself when nothing was split, or the generated
 * parts otherwise.
 */
class GeometryIndexSplitter
{
public:
    typedef std::vector<osg::ref_ptr<osg::Geometry> > GeometryList;

    /**
     * @param maxIndex             largest index value allowed in a DrawElements
     * @param disablePostTransform when true, the glesUtil::VertexCacheVisitor
     *                             pass is not run on the geometry
     */
    GeometryIndexSplitter(unsigned int maxIndex, bool disablePostTransform):
        _maxIndexToSplit(maxIndex), _disablePostTransform(disablePostTransform)
    {}

    /**
     * Splits `geometry` if at least one of its indices is above _maxIndexToSplit.
     *
     * @return true when the geometry has been split (the parts are in
     *         _geometryList), false when it could not or did not need to be
     *         split (in that case _geometryList only contains `geometry`).
     */
    bool split(osg::Geometry& geometry) {
        if(!hasValidPrimitives(geometry) || !needToSplit(geometry)) {
            if(!_disablePostTransform) {
                // optimize cache for rendering
                glesUtil::VertexCacheVisitor posttransform;
                posttransform.optimizeVertices(geometry);
            }
            _geometryList.push_back(&geometry); // remap geometry to itself
            return false;
        }

        // keep bounding box data as user value if needed in subsequent processing
        attachBufferBoundingBox(geometry);

        // first optimize primitives indexing
        {
            if(!_disablePostTransform) {
                // post-transform for better locality
                glesUtil::VertexCacheVisitor posttransform;
                posttransform.optimizeVertices(geometry);
            }
            // pre-transform to reindex correctly
            glesUtil::VertexAccessOrderVisitor pretransform;
            pretransform.optimizeOrder(geometry);
        }

        // Clone geometry as we will modify vertex arrays and primitives
        // To avoid issues with shared buffers, arrays & primitives should be
        // deep cloned.
        // As UserValues might be changed on a per-geometry basis afterwards, we
        // also deep clone userdata
        // We do *not* want to clone statesets as they reference a UniqueID that
        // should be unique (see #83056464).
        osg::ref_ptr<osg::Geometry> processing = osg::clone(&geometry, osg::CopyOp::DEEP_COPY_ARRAYS |
                                                                       osg::CopyOp::DEEP_COPY_PRIMITIVES |
                                                                       osg::CopyOp::DEEP_COPY_USERDATA);
        osg::ref_ptr<osg::Geometry> reported;
        while (true) {
            // `processing` keeps what fits below the limit, `reported` (if any)
            // receives the remaining primitives and is processed on the next turn
            reported = doSplit(*processing);

            // reduce vertex array if needed
            if(processing && processing->getNumPrimitiveSets()) {
                GeometryArrayList arrayList(*processing);
                arrayList.setNumElements(osg::minimum(arrayList.size(), _maxIndexToSplit + 1));
                _geometryList.push_back(processing);
            }

            if (!reported.valid()) {
                break;
            }
            else {
                processing = reported;
                reported = 0;

                // re order index elements
                glesUtil::VertexAccessOrderVisitor preTransform;
                preTransform.optimizeOrder(*processing);
            }
        }

        // do not dereference a missing vertex array just to print a count
        const osg::Array* vertices = geometry.getVertexArray();
        osg::notify(osg::NOTICE) << "geometry " << &geometry << " " << geometry.getName()
                                 << " vertexes (" << (vertices ? vertices->getNumElements() : 0u)
                                 << ") has DrawElements index > " << _maxIndexToSplit << ", splitted to "
                                 << _geometryList.size() << " geometry" << std::endl;

        return true;
    }

protected:
    /**
     * Checks that every (non null) primitive set of `geometry` is an indexed
     * POINTS, LINES or TRIANGLES primitive; these are the only ones the
     * splitter handles.
     */
    bool hasValidPrimitives(osg::Geometry& geometry) const {
        for (unsigned int i = 0; i < geometry.getNumPrimitiveSets(); ++ i) {
            osg::PrimitiveSet* primitive = geometry.getPrimitiveSet(i);
            if (primitive) {
                if (!primitive->getDrawElements()) {
                    osg::notify(osg::INFO) << "can't split Geometry " << geometry.getName()
                                           << " (" << &geometry << ") contains non indexed primitives"
                                           << std::endl;
                    return false;
                }

                switch (primitive->getMode()) {
                    case osg::PrimitiveSet::TRIANGLES:
                    case osg::PrimitiveSet::LINES:
                    case osg::PrimitiveSet::POINTS:
                        break;
                    default:
                        osg::notify(osg::FATAL) << "can't split Geometry " << geometry.getName()
                                                << " (" << &geometry << ") contains non point/line/triangle primitives"
                                                << std::endl;
                        return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns true when at least one primitive of `geometry` references an
     * index above _maxIndexToSplit. Null or non indexed primitive sets are
     * ignored, the same way hasValidPrimitives tolerates null entries.
     */
    bool needToSplit(const osg::Geometry& geometry) const {
        for(unsigned int i = 0; i < geometry.getNumPrimitiveSets(); ++ i) {
            const osg::PrimitiveSet* primitiveSet = geometry.getPrimitiveSet(i);
            const osg::DrawElements* primitive = primitiveSet ? primitiveSet->getDrawElements() : 0;
            if (primitive && needToSplit(*primitive)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true when `primitive` references an index above _maxIndexToSplit. */
    bool needToSplit(const osg::DrawElements& primitive) const {
        for(unsigned int j = 0; j < primitive.getNumIndices(); j++) {
            if (primitive.index(j) > _maxIndexToSplit){
                return true;
            }
        }
        return false;
    }

    /**
     * Stores the bounding box of the vertex array (when it is a Vec3Array) and
     * of every Vec2Array texture coordinate array as user values on the arrays.
     */
    void attachBufferBoundingBox(osg::Geometry& geometry) const {
        // positions
        setBufferBoundingBox(dynamic_cast<osg::Vec3Array*>(geometry.getVertexArray()));
        // uvs
        for(unsigned int i = 0 ; i < geometry.getNumTexCoordArrays() ; ++ i) {
            setBufferBoundingBox(dynamic_cast<osg::Vec2Array*>(geometry.getTexCoordArray(i)));
        }
    }

    /**
     * Computes the per-component minimum and maximum of `buffer` and attaches
     * them to the buffer as the "bbl" (minimum) and "ufr" (maximum) user
     * values. Does nothing for a null or empty buffer.
     */
    template<typename T>
    void setBufferBoundingBox(T* buffer) const {
        if(!buffer) return;

        typename T::ElementDataType bbl;
        typename T::ElementDataType ufr;
        const unsigned int dimension = buffer->getDataSize();

        if(buffer->getNumElements()) {
            // seed both corners with the first element
            for(unsigned int i = 0 ; i < dimension ; ++i) {
                bbl[i] = ufr[i] = (*buffer->begin())[i];
            }

            for(typename T::const_iterator it = buffer->begin() + 1 ; it != buffer->end() ; ++ it) {
                for(unsigned int i = 0 ; i < dimension ; ++ i) {
                    bbl[i] = std::min(bbl[i], (*it)[i]);
                    ufr[i] = std::max(ufr[i], (*it)[i]);
                }
            }

            buffer->setUserValue("bbl", bbl);
            buffer->setUserValue("ufr", ufr);
        }
    }

    /**
     * Keeps in `geometry` the primitives whose indices are all
     * <= _maxIndexToSplit and moves the others to a new geometry.
     *
     * @return a new geometry (ownership is given to the caller) holding the
     *         primitives that did not fit, or 0 when everything fitted.
     */
    osg::Geometry* doSplit(osg::Geometry& geometry) const {
        osg::Geometry::PrimitiveSetList geomPrimitives;
        osg::Geometry::PrimitiveSetList wirePrimitives;
        osg::Geometry::PrimitiveSetList reportedPrimitives;
        std::vector< osg::ref_ptr<osg::DrawElements> > primitivesToFix;

        osg::Geometry::PrimitiveSetList& primitives = geometry.getPrimitiveSetList();
        std::set<unsigned int> validIndices;
        osg::ref_ptr<osg::Geometry> reportGeometry;

        for (unsigned int i = 0; i < primitives.size(); i++) {
            osg::DrawElements *primitive = primitives[i].valid() ? primitives[i]->getDrawElements() : 0;
            if(!primitive) {
                // null entries carry no indices: nothing to keep nor to report
                continue;
            }

            bool isWireframe = false;
            // the primitive is handled as wireframe as soon as the user value exists
            if(primitive->getUserValue("wireframe", isWireframe)) {
                wirePrimitives.push_back(primitive);
            }
            else if (needToSplit(*primitive)) {
                primitivesToFix.push_back(primitive);
            }
            else {
                geomPrimitives.push_back(primitive);
                setValidIndices(validIndices, primitive);
            }
        }

        if (!primitivesToFix.empty()) {
            // filter all indices > _maxIndexValue in primitives
            for (unsigned int i = 0; i < primitivesToFix.size(); i++) {
                // removeLargeIndices only works on 32 bits indices; other index
                // types are converted instead of dereferencing a failed cast
                osg::ref_ptr<osg::DrawElementsUInt> source = toDrawElementsUInt(primitivesToFix[i].get());
                osg::DrawElements* large = removeLargeIndices(source.get());
                reportedPrimitives.push_back(large);
                geomPrimitives.push_back(source.get());
                setValidIndices(validIndices, source.get());
            }
        }

        // keep wireframe data associated to the current solid geometry
        extractWireframePrimitive(wirePrimitives, validIndices, geomPrimitives, reportedPrimitives);

        geometry.setPrimitiveSetList(geomPrimitives);
        if (!reportedPrimitives.empty()) {
            reportGeometry = osg::clone(&geometry, osg::CopyOp::DEEP_COPY_ARRAYS |
                                                   osg::CopyOp::DEEP_COPY_USERDATA);
            reportGeometry->setPrimitiveSetList(reportedPrimitives);
            return reportGeometry.release();
        }
        return 0;
    }

    /**
     * Returns `primitive` itself when it already is a DrawElementsUInt,
     * otherwise a DrawElementsUInt with the same mode, name, indices and
     * (shared) user data container. `primitive` must not be null.
     */
    osg::ref_ptr<osg::DrawElementsUInt> toDrawElementsUInt(osg::DrawElements* primitive) const {
        osg::DrawElementsUInt* asUInt = dynamic_cast<osg::DrawElementsUInt*>(primitive);
        if(asUInt) {
            return osg::ref_ptr<osg::DrawElementsUInt>(asUInt);
        }

        osg::ref_ptr<osg::DrawElementsUInt> converted = new osg::DrawElementsUInt(primitive->getMode());
        converted->setName(primitive->getName());
        converted->setUserDataContainer(primitive->getUserDataContainer());
        for(unsigned int j = 0 ; j < primitive->getNumIndices() ; ++ j) {
            converted->addElement(primitive->index(j));
        }
        return converted;
    }

    /** Adds every index referenced by `primitive` to `indices`. */
    void setValidIndices(std::set<unsigned int>& indices, const osg::DrawElements* primitive) const {
        for(unsigned int j = 0 ; j < primitive->getNumIndices() ; ++ j) {
            indices.insert(primitive->index(j));
        }
    }

    /**
     * Moves out of `source` every point/line/triangle having at least one index
     * above _maxIndexToSplit.
     *
     * @return a new DrawElementsUInt (same mode as `source`) holding the
     *         removed primitives, in reverse order of their position in `source`.
     */
    osg::DrawElements* removeLargeIndices(osg::DrawElementsUInt* source) const {
        osg::DrawElementsUInt* large = new osg::DrawElementsUInt(source->getMode());
        // number of indices per primitive; stays 0 (nothing is moved) for
        // modes other than POINTS/LINES/TRIANGLES
        unsigned int primitive_size = 0;
        switch(source->getMode()) {
            case osg::PrimitiveSet::POINTS:
                primitive_size = 1;
            break;
            case osg::PrimitiveSet::LINES:
                primitive_size = 2;
            break;
            case osg::PrimitiveSet::TRIANGLES:
                primitive_size = 3;
            break;
        }

        // walk backwards so that erasing a primitive does not shift the
        // position of the primitives still to be visited
        for (int id = source->getNumPrimitives() - 1; id >= 0; -- id) {
            const unsigned int arrayIndex = id * primitive_size;
            for(unsigned int i = 0 ; i < primitive_size ; ++ i) {
                if(source->index(arrayIndex + i) > _maxIndexToSplit) {
                    // add primitive in the large DrawElements
                    for(unsigned int j = 0 ; j < primitive_size ; ++ j) {
                        large->addElement(source->index(arrayIndex + j));
                    }
                    // remove primitive from source DrawElements
                    for(int j = primitive_size - 1 ; j >= 0 ; -- j) {
                        source->erase(source->begin() + arrayIndex + j);
                    }
                    break; // skip to next primitive
                }
            }
        }
        return large;
    }

    /**
     * Dispatches the segments of the wireframe primitives `lines`: a segment
     * whose two indices are in `indices` is appended to `primitives`, the other
     * segments are appended to `reported`. Generated primitives are flagged
     * with the "wireframe" user value.
     *
     * When `indices` is empty nothing is done. Primitives that are not
     * DrawElementsUInt in LINES mode are skipped.
     */
    void extractWireframePrimitive(osg::Geometry::PrimitiveSetList& lines,
                                   const std::set<unsigned int>& indices,
                                   osg::Geometry::PrimitiveSetList& primitives,
                                   osg::Geometry::PrimitiveSetList& reported) const {
        if(indices.empty()) {
            return;
        }

        for(unsigned int i = 0 ; i < lines.size() ; ++ i) {
            const osg::DrawElementsUInt* line = dynamic_cast<osg::DrawElementsUInt*>(lines[i].get());
            if(!line || line->getMode() != osg::PrimitiveSet::LINES) {
                osg::notify(osg::INFO) << "Primitive with bad mode flagged as wireframe. Skipping."
                                       << std::endl;
                // actually skip: `line` may be null and must not be dereferenced
                continue;
            }
            osg::ref_ptr<osg::DrawElementsUInt> small = new osg::DrawElementsUInt(osg::PrimitiveSet::LINES);
            osg::ref_ptr<osg::DrawElementsUInt> large = new osg::DrawElementsUInt(osg::PrimitiveSet::LINES);

            for (unsigned int id = 0 ; id < line->getNumPrimitives() ; ++ id) {
                unsigned int arrayIndex = id * 2;
                unsigned int a = line->index(arrayIndex),
                             b = line->index(arrayIndex + 1);

                if(indices.find(a) != indices.end() && indices.find(b) != indices.end()) {
                    small->addElement(a);
                    small->addElement(b);
                }
                else {
                    large->addElement(a);
                    large->addElement(b);
                }
            }

            if(small->size()) {
                small->setUserValue("wireframe", true);
                primitives.push_back(small);
            }

            if(large->size()) {
                large->setUserValue("wireframe", true);
                reported.push_back(large);
            }
        }
    }

public:
    const unsigned int _maxIndexToSplit; // largest index value allowed in a primitive
    bool _disablePostTransform;          // skip the VertexCacheVisitor pass when true
    GeometryList _geometryList;          // result of split(): geometries standing for the input
};



class GeometrySplitterVisitor : public GeometryUniqueVisitor {
public:
    typedef std::vector< osg::ref_ptr<osg::Geometry> > GeometryList;

    GeometrySplitterVisitor(unsigned int maxIndexValue = 65535, bool disablePostTransform=false):
        GeometryUniqueVisitor("GeometrySplitterVisitor"),
        _maxIndexValue(maxIndexValue),
        _disablePostTransform(disablePostTransform)
    {}

    void apply(osg::Geode& geode) {
        GeometryUniqueVisitor::apply(geode);
        GeometryList remapped;
        for(unsigned int i = 0 ; i < geode.getNumDrawables() ; ++ i) {
            osg::Geometry* geometry = geode.getDrawable(i)->asGeometry();
            if(geometry) {
                std::map<osg::Geometry*, GeometryList>::iterator lookup = _split.find(geometry);
                if(lookup != _split.end() && !lookup->second.empty()) {
                    remapped.insert(remapped.end(), lookup->second.begin(), lookup->second.end());
                }
            }
        }
        // remove all drawables
        geode.removeDrawables(0, geode.getNumDrawables());
        for(unsigned int i = 0 ; i < remapped.size() ; ++ i) {
            geode.addDrawable(remapped[i].get());
        }
    }

    void apply(osg::Geometry& geometry) {
        GeometryIndexSplitter splitter(_maxIndexValue, _disablePostTransform);
        splitter.split(geometry);
        setProcessed(&geometry, splitter._geometryList);
    }

protected:
    bool isProcessed(osg::Geometry* node) {
        return _split.find(node) != _split.end();
    }

    void setProcessed(osg::Geometry* node, const GeometryList& list) {
        _split.insert(std::pair<osg::Geometry*, GeometryList>(node, GeometryList(list)));
    }

    unsigned int _maxIndexValue;
    std::map<osg::Geometry*, GeometryList> _split;
    bool _disablePostTransform;
};

#endif
