/* 
 CutrColorByRayType.cpp
 RfM22.0
*/
#include <RixPattern.h> 
#include <RixShadingUtils.h>
#include <RixPredefinedStrings.hpp>
#include <RixBxdfLobe.h>
  
// RixPattern plugin producing one color output, "resultRGB". For each shading
// point the color is chosen from the lobe that was sampled to reach it:
// refract_rgb for transmission, reflect_rgb for reflection (each gated by its
// *_active toggle), and direct_rgb otherwise.
//
// The virtual members below override the RixPattern interface declared in
// RixPattern.h, so their signatures must stay exactly as written.
class CutrColorByRayType : public RixPattern {
public:
    CutrColorByRayType();
    virtual ~CutrColorByRayType() { }
    // Caches renderer interfaces; returns 1 if RixMessages is unavailable, else 0.
    virtual int Init(RixContext &, RtUString const pluginpath);
    // Returns the output/input parameter table (order matches paramIndex).
    virtual RixSCParamInfo const *GetParamTable();
    // No-op: this pattern keeps no scene-dependent state.
    virtual void Synchronize(RixContext&, RixSCSyncMsg, RixParameterList const*) { }
    // Creates no instance data and returns -1; ComputeOutputParams() ignores
    // its instanceData argument.
    virtual int CreateInstanceData(RixContext&, RtUString const, RixParameterList const*, InstanceData*) {
        return -1;
        }
    // No-op: nothing to release.
    virtual void Finalize(RixContext &) { }
    // Writes "resultRGB" for every shading point in ctx.
    virtual int ComputeOutputParams(RixShadingContext const *ctx,
                                    RtInt *noutputs, 
                                    RixPattern::OutputSpec **outputs,
                                    RtPointer instanceData,
                                    RixSCParamInfo const *ignored);
    // Baking is not supported: both always return false.
    virtual bool Bake2dOutput(RixBakeContext const*, Bake2dSpec&, RtPointer) { return false; }
    virtual bool Bake3dOutput(RixBakeContext const*, Bake3dSpec&, RtPointer) { return false; }
  
    private:
        // Messaging interface fetched in Init(); Init() fails when it is null.
        RixMessages *m_msg;
        RixShadeFunctions *m_shd;  // Shading functions in RixInterfaces.h
        // Values set by the constructor and passed to EvalParam() as the
        // defaults for the matching inputs in ComputeOutputParams().
        // Declaration order must match the constructor's initializer list.
        RtColorRGB const m_direct_rgb;
        RtColorRGB const m_reflect_rgb;
        RtInt const m_reflect_active;
        RtColorRGB const m_refract_rgb;
        RtInt const m_refract_active;
    };
  
// Starts with null interface pointers; Init() fetches them from the context.
// The remaining members are the defaults handed to EvalParam() by
// ComputeOutputParams():
//   direct  = (1,1,1) white
//   reflect = (1,0,0) red,   enabled
//   refract = (0,1,0) green, enabled
CutrColorByRayType::CutrColorByRayType():
    m_msg(nullptr),
    m_shd(nullptr),
    m_direct_rgb(1,1,1),
    m_reflect_rgb(1,0,0),
    m_reflect_active(1),
    m_refract_rgb(0,1,0),
    m_refract_active(1)
    { }
  
// Caches the renderer interfaces this plugin uses.
//
// ctx        - context the interfaces are fetched from.
// pluginpath - unused.
// Returns 0 on success, or 1 if the RixMessages interface is unavailable.
// NOTE(review): m_shd is stored but neither validated here nor used by
// ComputeOutputParams().
int CutrColorByRayType::Init(RixContext &ctx, RtUString const pluginpath) {
    (void)pluginpath;
    m_msg = static_cast<RixMessages*>(ctx.GetRixInterface(k_RixMessages));
    m_shd = static_cast<RixShadeFunctions*>(ctx.GetRixInterface(k_RixShadeFunctions));

    // Rib Options should be queried here rather than in ComputeOutputParams().
    // Uncomment the next three lines if a rib Option will be queried.
    //RixRenderState *rstate = static_cast<RixRenderState*>(ctx.GetRixInterface(k_RixRenderState));
    //RixRenderState::Type optType;
    //RtInt optNumValues, err;
    // Example of using messaging:
    //    m_msg->Info("%f\n", a_float_value);

    if (!m_msg)
        return 1;
    return 0;
    }
  
// Describes this pattern's parameters: the single color output "resultRGB",
// then the five inputs, then an empty terminating entry.
// ComputeOutputParams() addresses these entries with the paramIndex values,
// so the order here must stay in step with that enumeration.
RixSCParamInfo const *CutrColorByRayType::GetParamTable()
{
    static RixSCParamInfo s_params[] =
    {
        // ---- output ----
        RixSCParamInfo(RtUString("resultRGB"),      k_RixSCColor, k_RixSCOutput),

        // ---- inputs ----
        RixSCParamInfo(RtUString("direct_rgb"),     k_RixSCColor),
        RixSCParamInfo(RtUString("reflect_rgb"),    k_RixSCColor),
        RixSCParamInfo(RtUString("reflect_active"), k_RixSCInteger),
        RixSCParamInfo(RtUString("refract_rgb"),    k_RixSCColor),
        RixSCParamInfo(RtUString("refract_active"), k_RixSCInteger),

        // ---- terminator ----
        RixSCParamInfo()
    };
    return s_params;
}
  
// Position of each entry in the table returned by
// CutrColorByRayType::GetParamTable(). Both lists must keep the same order.
enum paramIndex {
    k_resultRGB      = 0,   // output
    k_direct_rgb     = 1,   // inputs follow
    k_reflect_rgb    = 2,
    k_reflect_active = 3,
    k_refract_rgb    = 4,
    k_refract_active = 5
};
    
// Computes the single "resultRGB" output for all ctx->numPts shading points.
//
// Each point's color depends on the lobe that was sampled to reach it
// (the k_incidentLobeSampled builtin):
//   - transmission: refract_rgb if refract_active is nonzero, else direct_rgb
//   - reflection:   reflect_rgb if reflect_active is nonzero, else direct_rgb
//   - otherwise:    direct_rgb
// If the builtin is not supplied (its pointer stays null), every point gets
// direct_rgb. This avoids indexing a null array.
//
// noutputs     - receives the number of outputs written (1).
// outputs      - receives an OutputSpec array allocated from the pattern pool.
// instanceData - unused.
// ignored      - unused.
// Returns 0.
//
// Rib Attributes, if ever needed, may be queried here via
//     RixRenderState *rstate = (RixRenderState*)ctx->GetRixInterface(k_RixRenderState);
// Rib Options should be queried in Init() instead.
int CutrColorByRayType::ComputeOutputParams(RixShadingContext const *ctx,
                                RtInt *noutputs, 
                                RixPattern::OutputSpec **outputs,
                                RtPointer instanceData,
                                RixSCParamInfo const *ignored) {
    (void)instanceData;
    (void)ignored;

    // OUTPUTS ____________________________________________
    RixShadingContext::Allocator pool(ctx);
    OutputSpec *outSpec = pool.AllocForPattern<OutputSpec>(1);
    *outputs = outSpec;

    // One color per shading point.
    RtColorRGB *resultRGB = pool.AllocForPattern<RtColorRGB>(ctx->numPts);

    *noutputs = 1;
    outSpec[0].paramId = k_resultRGB;
    outSpec[0].detail = k_RixSCVarying;
    outSpec[0].value = resultRGB;

    // INPUTS _____________________________________________
    // Colors are promoted to varying so they can be indexed per point.
    // The on/off toggles are evaluated uniform and read through element 0 only.
    bool const varying = true;
    bool const uniform = false;
    RtColorRGB const *direct_rgb;
    RtColorRGB const *reflect_rgb;
    RtInt const *reflect_activePtr;
    RtColorRGB const *refract_rgb;
    RtInt const *refract_activePtr;
    ctx->EvalParam(k_direct_rgb, -1, &direct_rgb, &m_direct_rgb, varying);
    ctx->EvalParam(k_reflect_rgb, -1, &reflect_rgb, &m_reflect_rgb, varying);
    ctx->EvalParam(k_reflect_active, -1, &reflect_activePtr, &m_reflect_active, uniform);
    ctx->EvalParam(k_refract_rgb, -1, &refract_rgb, &m_refract_rgb, varying);
    ctx->EvalParam(k_refract_active, -1, &refract_activePtr, &m_refract_active, uniform);

    // The toggles are uniform, so read them once rather than per point.
    bool const reflectActive = (*reflect_activePtr != 0);
    bool const refractActive = (*refract_activePtr != 0);

    // Per-point identifier of the lobe that was sampled to reach the point.
    const int *incidentLobeSampled = nullptr;
    ctx->GetBuiltinVar(RixShadingContext::k_incidentLobeSampled, &incidentLobeSampled);

    if (!incidentLobeSampled) {
        // No lobe information: treat every point as matching no lobe.
        for (int i = 0; i < ctx->numPts; i++)
            resultRGB[i] = direct_rgb[i];
        return 0;
        }

    for (int i = 0; i < ctx->numPts; i++) {
        RixBXLobeSampled incidentRayType = incidentLobeSampled[i];
        // Transmission takes precedence if a lobe reports both.
        if (incidentRayType.GetTransmit())
            resultRGB[i] = refractActive ? refract_rgb[i] : direct_rgb[i];
        else if (incidentRayType.GetReflect())
            resultRGB[i] = reflectActive ? reflect_rgb[i] : direct_rgb[i];
        else
            resultRGB[i] = direct_rgb[i];
        }
    return 0;
    }
// Factory entry point declared by the RIX_PATTERNCREATE macro: hands out a
// freshly allocated pattern. The object is deleted by RIX_PATTERNDESTROY.
RIX_PATTERNCREATE
{
    CutrColorByRayType *instance = new CutrColorByRayType();
    return instance;
}
// Release entry point declared by the RIX_PATTERNDESTROY macro: deletes an
// object previously returned by RIX_PATTERNCREATE. static_cast is a
// compile-time-checked downcast along CutrColorByRayType's public inheritance
// from RixPattern. Unlike a C-style cast, it cannot silently reinterpret an
// unrelated type.
RIX_PATTERNDESTROY {
    delete static_cast<CutrColorByRayType*>(pattern);
    }