public class LogisticRegression extends ProbabilisticClassifier<Vector,LogisticRegression,LogisticRegressionModel> implements Logging
| Constructor and Description |
|---|
| LogisticRegression() |
| LogisticRegression(java.lang.String uid) |
| Modifier and Type | Method and Description |
|---|---|
| void | checkThresholdConsistency(): If threshold and thresholds are both set, ensures they are consistent. |
| LogisticRegression | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
| double | getThreshold(): Get threshold for binary classification. |
| double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
| LogisticRegression | setElasticNetParam(double value): Set the ElasticNet mixing parameter. |
| LogisticRegression | setFitIntercept(boolean value): Whether to fit an intercept term. |
| LogisticRegression | setMaxIter(int value): Set the maximum number of iterations. |
| LogisticRegression | setRegParam(double value): Set the regularization parameter. |
| LogisticRegression | setStandardization(boolean value): Whether to standardize the training features before fitting the model. |
| LogisticRegression | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
| LogisticRegression | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
| LogisticRegression | setTol(double value): Set the convergence tolerance of iterations. |
| protected LogisticRegressionModel | train(DataFrame dataset): Train a model using the given dataset and parameters. |
| java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
| void | validateParams() |

Methods inherited from class ProbabilisticClassifier: setProbabilityCol
Methods inherited from class Classifier: setRawPredictionCol
Methods inherited from class Predictor: extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString

public LogisticRegression(java.lang.String uid)

public LogisticRegression()

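public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
Specified by:
uid in interface Identifiable

public LogisticRegression setRegParam(double value)
Set the regularization parameter.
Parameters:
value - (undocumented)

public LogisticRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter.
Parameters:
value - (undocumented)

public LogisticRegression setMaxIter(int value)
Set the maximum number of iterations.
Parameters:
value - (undocumented)

public LogisticRegression setTol(double value)
Set the convergence tolerance of iterations.
Parameters:
value - (undocumented)

public LogisticRegression setFitIntercept(boolean value)
Whether to fit an intercept term.
Parameters:
value - (undocumented)

public LogisticRegression setStandardization(boolean value)
Whether to standardize the training features before fitting the model.
Parameters:
value - (undocumented)

public LogisticRegression setThreshold(double value)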
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be
equivalent.
Default is 0.5.
Parameters:
value - (undocumented)

public double getThreshold()
If threshold is set, returns that value.
Otherwise, if thresholds is set with length 2 (i.e., binary classification),
this returns the equivalent threshold:
1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns the default value of threshold.
Throws:
java.lang.IllegalArgumentException - if thresholds is set to an array of length other than 2.

public LogisticRegression setThresholds(double[] value)
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be
equivalent.
Overrides:
setThresholds in class ProbabilisticClassifier<Vector,LogisticRegression,LogisticRegressionModel>
Parameters:
value - (undocumented)

public double[] getThresholds()
If thresholds is set, return its value.
Otherwise, if threshold is set, return the equivalent thresholds for binary
classification: (1-threshold, threshold).
If neither is set, throws an exception.

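The relationship between threshold and thresholds described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not Spark's implementation: the class `ThresholdSketch` and its methods `toThresholds`, `toThreshold`, and `predict` are hypothetical names introduced here, and the sketch does not depend on Spark.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of Spark) mirroring the documented conversions. */
final class ThresholdSketch {

    /** Binary thresholds equivalent to a single threshold p: (1 - p, p). */
    static double[] toThresholds(double threshold) {
        return new double[] {1.0 - threshold, threshold};
    }

    /** Single threshold equivalent to a length-2 thresholds array. */
    static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "Expected 2 thresholds for binary classification, got " + thresholds.length);
        }
        // Formula from getThreshold(): 1 / (1 + thresholds(0) / thresholds(1))
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    /** Decision rule from setThreshold(): predict 1 only if P(label = 1) > threshold. */
    static int predict(double probabilityOfOne, double threshold) {
        return probabilityOfOne > threshold ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] pair = toThresholds(0.7);
        System.out.println(Arrays.toString(pair)); // approximately [0.3, 0.7]
        System.out.println(toThreshold(pair));     // approximately 0.7 (round trip)
        System.out.println(predict(0.6, 0.7));     // prints 0: 0.6 is not above 0.7
        System.out.println(predict(0.6, 0.5));     // prints 1: 0.6 is above 0.5
    }
}
```

Substituting (1 - p, p) into the formula gives 1 / (1 + (1 - p) / p) = p, which is why the two representations are interchangeable for binary classification.
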
protected LogisticRegressionModel train(DataFrame dataset)
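Train a model using the given dataset and parameters. Developers can implement this instead of fit() to avoid dealing with schema validation and copying parameters into the model.
Specified by:
train in class Predictor<Vector,LogisticRegression,LogisticRegressionModel>
Parameters:
dataset - Training dataset

public LogisticRegression copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
Specified by:
copy in interface Params
Specified by:
copy in class Predictor<Vector,LogisticRegression,LogisticRegressionModel>
Parameters:
extra - (undocumented)
See Also:
defaultCopy()

public void checkThresholdConsistency()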
If threshold and thresholds are both set, ensures they are consistent.
Throws:
java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent

public void validateParams()

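As an illustration of the check documented for checkThresholdConsistency(), the following plain-Java sketch compares a threshold against the value implied by a thresholds pair. It is not Spark's code: the class `ThresholdConsistencySketch`, the method `check`, the use of `null` to mean "not set", and the tolerance `EPS` are all assumptions made for this example. Treating a thresholds array of length other than 2 as inconsistent is also an assumption.

```java
/** Hypothetical sketch (not Spark's implementation) of a threshold consistency check. */
final class ThresholdConsistencySketch {

    // Assumed tolerance for comparing doubles; the documentation does not specify one.
    static final double EPS = 1e-5;

    /**
     * Throws IllegalArgumentException if both values are set and not equivalent.
     * A null argument stands in for "param not set" in this sketch.
     */
    static void check(Double threshold, double[] thresholds) {
        if (threshold == null || thresholds == null) {
            return; // nothing to compare unless both are set
        }
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "threshold is set, but thresholds has length " + thresholds.length);
        }
        // Equivalent single threshold: 1 / (1 + thresholds(0) / thresholds(1))
        double equivalent = 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        if (Math.abs(equivalent - threshold) > EPS) {
            throw new IllegalArgumentException(
                "threshold (" + threshold + ") and thresholds (equivalent to "
                    + equivalent + ") are not equivalent");
        }
    }

    public static void main(String[] args) {
        check(0.8, new double[] {0.2, 0.8}); // consistent: no exception
        try {
            check(0.5, new double[] {0.2, 0.8}); // inconsistent: 0.5 vs 0.8
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```
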
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.