public class GBTClassificationModel extends ProbabilisticClassificationModel<Vector,GBTClassificationModel> implements GBTClassifierParams, TreeEnsembleModel<DecisionTreeRegressionModel>, MLWritable, scala.Serializable
param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.
| Constructor and Description |
|---|
| GBTClassificationModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights) - Construct a GBTClassificationModel |
| Modifier and Type | Method and Description |
|---|---|
| GBTClassificationModel | copy(ParamMap extra) - Creates a copy of this instance with the same UID and some extra params. |
| double[] | evaluateEachIteration(Dataset<?> dataset) - Method to compute error or loss for every iteration of gradient boosting. |
| Vector | featureImportances() - Estimate of the importance of each feature. |
| int | getNumTrees() - Number of trees in ensemble. |
| static GBTClassificationModel | load(String path) |
| int | numClasses() - Number of classes (values which the label can take). |
| int | numFeatures() - Returns the number of features the model was trained on. |
| int | numTrees() - Deprecated. Use getNumTrees instead. This method will be removed in 3.0.0. |
| double | predict(Vector features) - Predict label for the given features. |
| static MLReader<GBTClassificationModel> | read() |
| String | toString() - Summary of the model. |
| DecisionTreeRegressionModel[] | trees() - Trees in this ensemble. |
| double[] | treeWeights() - Weights for each tree, zippable with trees. |
| String | uid() - An immutable unique ID for the object and its derivatives. |
| MLWriter | write() - Returns an MLWriter instance for this ML instance. |
Inherited methods:
normalizeToProbabilitiesInPlace, setProbabilityCol, setThresholds, transform
setRawPredictionCol
setFeaturesCol, setPredictionCol, transformSchema
transform, transform, transform
getLossType, getOldLossType, lossType
getOldBoostingStrategy, getValidationTol, setMaxIter, setStepSize, stepSize, validationTol
featureSubsetStrategy, getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate, setFeatureSubsetStrategy, setSubsamplingRate, subsamplingRate
cacheNodeIds, getCacheNodeIds, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getOldStrategy, maxBins, maxDepth, maxMemoryInMB, minInfoGain, minInstancesPerNode, setCacheNodeIds, setCheckpointInterval, setMaxBins, setMaxDepth, setMaxMemoryInMB, setMinInfoGain, setMinInstancesPerNode, setSeed
validateAndTransformSchema
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
checkpointInterval, getCheckpointInterval
getMaxIter, maxIter
getStepSize
getValidationIndicatorCol, validationIndicatorCol
getImpurity, getOldImpurity, impurity, setImpurity
javaTreeWeights, toDebugString, totalNumNodes
save
validateAndTransformSchema
getRawPredictionCol, rawPredictionCol
getProbabilityCol, probabilityCol
getThresholds, thresholds
initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning

public GBTClassificationModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Parameters:
_trees - Decision trees in the ensemble.
_treeWeights - Weights for the decision trees in the ensemble.
uid - (undocumented)

public static MLReader<GBTClassificationModel> read()
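public static GBTClassificationModel load(String path)

public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable

public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,GBTClassificationModel>

public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
Specified by: numClasses in class ClassificationModel<Vector,GBTClassificationModel>

public DecisionTreeRegressionModel[] trees()
Description copied from interface: TreeEnsembleModel
Trees in this ensemble.
Specified by: trees in interface TreeEnsembleModel<DecisionTreeRegressionModel>

public int getNumTrees()

public double[] treeWeights()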
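Description copied from interface: TreeEnsembleModel
Weights for each tree, zippable with trees.
Specified by: treeWeights in interface TreeEnsembleModel<DecisionTreeRegressionModel>

public double predict(Vector features)
Description copied from class: ClassificationModel
Predict label for the given features. This method is used to implement transform() and output predictionCol.
This default implementation for classification predicts the index of the maximum value from predictRaw().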
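The "index of the maximum value" rule can be pictured with a small plain-Java sketch. It does not call Spark: the class `GbtPredictSketch` and its method names are hypothetical, and the choice of a two-element raw vector `(-margin, margin)`, built from a weighted sum of per-tree outputs, is an assumption made for illustration of the binary case.

```java
// Illustrative sketch only; hypothetical names, not Spark's implementation.
final class GbtPredictSketch {

    /** Weighted sum of per-tree outputs; weights[i] pairs with treeOutputs[i]. */
    static double margin(double[] treeOutputs, double[] weights) {
        if (treeOutputs.length != weights.length) {
            throw new IllegalArgumentException("trees and weights must have equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < treeOutputs.length; i++) {
            sum += weights[i] * treeOutputs[i];
        }
        return sum;
    }

    /** Index of the largest entry; ties resolve to the lowest index. */
    static int argmax(double[] raw) {
        int best = 0;
        for (int i = 1; i < raw.length; i++) {
            if (raw[i] > raw[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Predicted label as a double, mirroring the return type of predict(Vector). */
    static double predict(double[] treeOutputs, double[] weights) {
        double m = margin(treeOutputs, weights);
        double[] raw = {-m, m}; // assumed raw-prediction layout for two classes
        return argmax(raw);
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 0.1, 0.1};
        System.out.println(predict(new double[] {0.8, -0.3, 0.5}, weights));  // prints 1.0
        System.out.println(predict(new double[] {-0.8, 0.3, 0.5}, weights)); // prints 0.0
    }
}
```

With this layout, taking the argmax is equivalent to checking the sign of the margin, which is why a zero margin falls to class 0 here.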
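Overrides: predict in class ClassificationModel<Vector,GBTClassificationModel>
Parameters:
features - (undocumented)

public int numTrees()
Deprecated. Use getNumTrees instead. This method will be removed in 3.0.0.

public GBTClassificationModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params, copy in class Model<GBTClassificationModel>
Parameters:
extra - (undocumented)

public String toString()
Description copied from interface: TreeEnsembleModel
Summary of the model.
Specified by: toString in interface TreeEnsembleModel<DecisionTreeRegressionModel>, toString in interface Identifiable
Overrides: toString in class Object

public Vector featureImportances()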
Estimate of the importance of each feature. Each feature's importance is the average of its importance across all trees in the ensemble. The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) and follows the implementation from scikit-learn.
See DecisionTreeClassificationModel.featureImportances
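The averaging and normalization described above can be sketched in plain Java. This is only an illustration of the arithmetic, not Spark's code: the class `FeatureImportanceSketch`, the method `averageAndNormalize`, and the per-tree importance values are all hypothetical.

```java
import java.util.Arrays;

// Illustrative sketch only; hypothetical names and made-up input values.
final class FeatureImportanceSketch {

    /**
     * Averages per-tree importance vectors, then rescales the result to sum to 1.
     * perTree[t][f] is the importance of feature f in tree t.
     */
    static double[] averageAndNormalize(double[][] perTree) {
        int numFeatures = perTree[0].length;
        double[] avg = new double[numFeatures];
        for (double[] tree : perTree) {
            for (int f = 0; f < numFeatures; f++) {
                avg[f] += tree[f] / perTree.length;
            }
        }
        double total = Arrays.stream(avg).sum();
        if (total == 0.0) {
            return avg; // no feature was ever used: leave the all-zero vector as is
        }
        for (int f = 0; f < numFeatures; f++) {
            avg[f] /= total;
        }
        return avg;
    }

    public static void main(String[] args) {
        double[][] perTree = {
            {3.0, 1.0, 0.0},
            {1.0, 1.0, 2.0}
        };
        System.out.println(Arrays.toString(averageAndNormalize(perTree)));
        // prints [0.5, 0.25, 0.25]
    }
}
```

Because the final vector sums to 1, each entry can be read as a feature's share of the total importance rather than as an absolute score.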
public double[] evaluateEachIteration(Dataset<?> dataset)
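As a rough picture of what a per-iteration loss array contains, the plain-Java sketch below adds one tree at a time to a running ensemble output and records the mean loss after each addition. It does not use Spark: the class and method names are hypothetical, the inputs are precomputed per-tree outputs rather than a Dataset, and the logistic-style loss formula is an assumption chosen for illustration.

```java
// Illustrative sketch only; names, inputs, and the loss formula are assumptions.
final class EvaluateEachIterationSketch {

    /**
     * Returns one mean loss per boosting iteration.
     * treeOutputs[t][n] is the output of tree t on example n,
     * treeWeights[t] pairs with tree t, and labels[n] is 0 or 1.
     */
    static double[] lossPerIteration(double[][] treeOutputs, double[] treeWeights, int[] labels) {
        int numTrees = treeOutputs.length;
        int numExamples = labels.length;
        double[] margins = new double[numExamples]; // running ensemble output per example
        double[] losses = new double[numTrees];
        for (int t = 0; t < numTrees; t++) {
            double total = 0.0;
            for (int n = 0; n < numExamples; n++) {
                margins[n] += treeWeights[t] * treeOutputs[t][n];
                double y = 2.0 * labels[n] - 1.0; // map {0, 1} to {-1, +1}
                // Assumed loss: 2 * log(1 + exp(-2 * y * margin)).
                total += 2.0 * Math.log1p(Math.exp(-2.0 * y * margins[n]));
            }
            losses[t] = total / numExamples;
        }
        return losses;
    }

    public static void main(String[] args) {
        double[][] treeOutputs = {
            {0.5, -0.5},
            {0.5, -0.5}
        };
        double[] treeWeights = {1.0, 0.1};
        int[] labels = {1, 0};
        for (double loss : lossPerIteration(treeOutputs, treeWeights, labels)) {
            System.out.println(loss);
        }
    }
}
```

In this toy input the second tree pushes both margins further in the correct direction, so the second printed loss is smaller than the first. A loss that stops shrinking on held-out data is one common signal for choosing the number of iterations.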
Parameters:
dataset - Dataset for validation.

public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable