# Used by particleCurvesRI.mel
# M.Kesson
# 11 March 2016
import maya.cmds as mc
import os
from maya_proj_utils import MayaProjUtils
import random
  
class Point:
    """One trail sample: a coordinate sequence plus the curve width there."""

    def __init__(self, xyz, width=0.02):
        # The coordinate sequence is kept by reference, not copied.
        self.__coord, self.__width = xyz, width

    def _component(self, axis):
        """Return item ``axis`` of the stored coordinate sequence."""
        return self.__coord[axis]

    def get_xyz(self):
        """Return the coordinate sequence given to the constructor."""
        return self.__coord

    def get_x(self):
        """Return the first coordinate component."""
        return self._component(0)

    def get_y(self):
        """Return the second coordinate component."""
        return self._component(1)

    def get_z(self):
        """Return the third coordinate component."""
        return self._component(2)

    def get_width(self):
        """Return the width given at construction (default 0.02)."""
        return self.__width
            
class PointDB:
    """Store one trail (a list of Point samples) per particle index.

    ``self.data[i]`` is the list of Point objects recorded, in insertion
    order, for particle index ``i``.
    """
    # Sentinel bounds: get_bbox() returns [MIN]*3 + [MAX]*3 (an inverted
    # box) when no points are stored.
    MIN =  10000000.0
    MAX = -10000000.0

    def __init__(self):
        # One list of Point objects per particle index.
        self.data = []

    def add(self, pnt_index, position, width):
        """Append a sample to the trail of particle ``pnt_index``.

        Only the first three items of ``position`` are stored.  If
        ``pnt_index`` is past the end of the current trails, empty trails are
        inserted for the skipped indices so the sample lands exactly at
        ``self.data[pnt_index]``.
        """
        pnt = Point(position[0:3], width)
        # Pad rather than blindly append: appending would put a sample for a
        # skipped index into a different particle's slot.
        while len(self.data) <= pnt_index:
            self.data.append([])
        self.data[pnt_index].append(pnt)

    def get_points(self, index):
        """Return the list of Point samples for particle ``index``."""
        return self.data[index]

    def get_first_coord(self, index):
        """Return the coordinates of the first sample of trail ``index``."""
        return self.data[index][0].get_xyz()

    def get_last_coord(self, index):
        """Return the coordinates of the last sample of trail ``index``."""
        return self.get_points(index)[-1].get_xyz()

    def get_numpoints(self):
        """Return the number of trails (one per particle index seen)."""
        return len(self.data)

    def get_bbox(self):
        """Return ``[minx, miny, minz, maxx, maxy, maxz]`` over all samples.

        When no samples are stored, the sentinel box
        ``[MIN, MIN, MIN, MAX, MAX, MAX]`` is returned.
        """
        xs, ys, zs = [], [], []
        for pnts in self.data:
            for pnt in pnts:
                x, y, z = pnt.get_xyz()
                xs.append(x)
                ys.append(y)
                zs.append(z)
        if not xs:
            return [PointDB.MIN] * 3 + [PointDB.MAX] * 3
        # Bounds come from the data itself, so coordinates beyond the
        # +/-1e7 sentinels are no longer clamped.
        return [min(xs), min(ys), min(zs), max(xs), max(ys), max(zs)]
  
#-----------------------------------------
# Instances of this class are created only on frame 1
class ParticleCurves(PointDB):
    """Bake the trails of a Maya particle node to RenderMan RIB curves.

    Each call to bakeCurves() samples the current position of every
    particle, appends it to that particle's trail, and writes all trails to
    a RIB archive file named after the current frame.
    """

    def __init__(self, tnode, rounded, attrname):
        """Set up the trail store and the RIB archive directory.

        tnode    -- name of the particle node whose positions are sampled.
        rounded  -- if true, the archive sets ``Attribute "dice" "hair" [1]``.
        attrname -- name of the per-curve "user" float attribute; an empty
                    value falls back to 'probability'.

        Creates the project's RIB archive directory if it is missing.
        """
        PointDB.__init__(self)
        self.geoname = tnode
        self.rounded = rounded
        self.attrname = attrname if attrname else 'probability'
        self.utils = MayaProjUtils()
        self.archivepath = self.utils.getRIB_ArchivePath()
        self._ensure_dir(self.archivepath)
        self.begin_frame = self.utils.getAnimationStart()
        self.end_frame = self.utils.getAnimationEnd()
        self.scenename = self.utils.getSceneName()

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and any missing parents) unless it is a directory."""
        if not os.path.isdir(path):
            try:
                os.makedirs(path)
            except OSError:
                # Tolerate another process creating it between check and call.
                if not os.path.isdir(path):
                    raise

    def update(self, width):
        """Append each particle's current position to its trail.

        width -- curve width stored with every new sample.
        """
        pnum = mc.particle(self.geoname, q=True, count=True)
        for n in range(pnum):
            pname = self.geoname + ".pt[%d]" % n
            pos = mc.getParticleAttr(pname, at='position')
            self.add(n, pos, width)

    def bakeCurves(self, width):
        """Sample the particles, then write every trail to this frame's RIB.

        The file starts with a ``#bbox:`` comment line.  Trails with fewer
        than four samples are not written.
        """
        self.update(width)
        ribpath = self.make_archive_path()
        # 'with' guarantees the file is closed even if a write fails.
        with open(ribpath, 'w') as fileid:
            fileid.write(self.get_bbox_str())
            fileid.write('AttributeBegin\n')
            if self.rounded:
                fileid.write('Attribute "dice" "hair" [1]\n')
            fileid.write('Basis "catmull-rom" 1 "catmull-rom" 1\n')
            for n in range(self.get_numpoints()):
                self._write_curve(fileid, n)
            fileid.write('AttributeEnd\n')

    def _write_curve(self, fileid, index):
        """Write trail ``index`` as one Curves primitive (skipped if < 4 samples)."""
        pnts = self.get_points(index)
        numcvs = len(pnts)
        if numcvs < 4:
            return
        # Per-curve user attribute, e.g. Attribute "user" "float some_name" [0.665].
        # A private generator seeded with the particle index yields the same
        # value every frame without reseeding the global random module.
        attrval = random.Random(index).uniform(0.0, 1.0)
        fileid.write('Attribute "user" "float %s" [%1.4f]\n' % (self.attrname, attrval))

        # The first and last samples are written twice so the catmull-rom
        # curve spans the whole trail, hence numcvs + 2 vertices.
        fileid.write('Curves "cubic" [%d] "nonperiodic" "P" [\n' % (numcvs + 2))
        x, y, z = self.get_first_coord(index)
        fileid.write('%1.3f %1.3f %1.3f\n' % (x, y, z))
        for pnt in pnts:
            x, y, z = pnt.get_xyz()
            fileid.write('%1.3f %1.3f %1.3f\n' % (x, y, z))
        x, y, z = self.get_last_coord(index)
        fileid.write('%1.3f %1.3f %1.3f]\n' % (x, y, z))
        # One width per original sample (numcvs values).
        fileid.write('"width" [\n')
        for pnt in pnts:
            fileid.write('%1.4f\n' % pnt.get_width())
        fileid.write(']\n')

    # __________________________________________________
    def make_archive_path(self):
        """Return <archivepath>/<scene>/<geoname>/<geoname>.<frame>.rib.

        The frame number is zero-padded to four digits.  The scene and
        geometry directories are created if they are missing.
        """
        if self.utils is not None:
            frame = self.utils.getCurrentTime()
        else:
            # Only reached if self.utils was cleared after __init__.
            frame = mc.currentTime(q=True)
        ribname = '%s.%04d.rib' % (self.geoname, frame)
        path = os.path.join(self.archivepath, self.scenename, self.geoname)
        self._ensure_dir(path)
        return os.path.join(path, ribname)

    # __________________________________________________
    # Puts some useful information at the beginning of the rib archive file.
    def get_bbox_str(self):
        """Return the ``#bbox: minx miny minz maxx maxy maxz`` header line."""
        x, y, z, X, Y, Z = self.get_bbox()
        return '#bbox: %1.4f %1.4f %1.4f %1.4f %1.4f %1.4f \n' % (x, y, z, X, Y, Z)
  
  
#            elif self.curve_type == 'linear' and numcvs >= 2:
#                fileid.write('Curves "linear" [%d] "nonperiodic" "P" [' % numcvs)
#                for pnt in pnts:
#                    x,y,z = pnt.get_xyz()
#                    fileid.write('%1.3f %1.3f %1.3f\n' % (x,y,z))
#                fileid.write('] "constantwidth" [%1.4f]\n' % width)
#