public class LinearRegression extends Predictor<FeaturesType,Learner,M> implements DefaultParamsWritable, Logging
The learning objective is to minimize the specified loss function, with regularization. This supports two kinds of loss:
- squaredError (a.k.a. squared loss)
- huber (a hybrid of squared error for relatively small errors and absolute error for relatively large ones; the scale parameter is estimated from the training data)
This supports multiple types of regularization:
- none (a.k.a. ordinary least squares)
- L2 (ridge regression)
- L1 (Lasso)
- L2 + L1 (elastic net)
The squared error objective function is:
$$ \begin{align} \min_{w}\frac{1}{2n}{\sum_{i=1}^n(X_{i}w - y_{i})^{2} + \lambda\left[\frac{1-\alpha}{2}{||w||_{2}}^{2} + \alpha{||w||_{1}}\right]} \end{align} $$
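As a concrete check of the squared error objective, the following Java sketch evaluates it directly on dense in-memory arrays. The class `SquaredErrorObjective` is hypothetical and not part of Spark: it is a naive single-machine evaluation that ignores standardization, instance weights, and distributed aggregation, and, as in the formula, it has no separate intercept term (`X_i w` is the whole prediction). Here `lambda` plays the role of regParam and `alpha` the role of elasticNetParam.

```java
/** Hypothetical helper: evaluates the squared error objective with an elastic-net penalty. */
public class SquaredErrorObjective {

    /**
     * (1/2n) * sum_i (X_i w - y_i)^2 + lambda * [ (1 - alpha)/2 * ||w||_2^2 + alpha * ||w||_1 ]
     *
     * @param x      n-by-d design matrix, one row per training instance
     * @param y      labels, length n
     * @param w      coefficients, length d
     * @param lambda regularization strength
     * @param alpha  elastic-net mixing parameter in [0, 1]
     */
    public static double objective(double[][] x, double[] y, double[] w,
                                   double lambda, double alpha) {
        int n = x.length;
        double sse = 0.0;
        for (int i = 0; i < n; i++) {
            // Prediction X_i w for instance i.
            double pred = 0.0;
            for (int j = 0; j < w.length; j++) {
                pred += x[i][j] * w[j];
            }
            double r = pred - y[i];
            sse += r * r;
        }
        // Squared L2 norm and L1 norm of the coefficients.
        double l2sq = 0.0;
        double l1 = 0.0;
        for (double wj : w) {
            l2sq += wj * wj;
            l1 += Math.abs(wj);
        }
        double penalty = lambda * ((1.0 - alpha) / 2.0 * l2sq + alpha * l1);
        return sse / (2.0 * n) + penalty;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 0.0}, {0.0, 1.0}};
        double[] y = {1.0, 2.0};
        double[] w = {1.0, 1.0};
        // Residuals are 0 and -1, so the data term is 1 / (2 * 2) = 0.25.
        System.out.println(objective(x, y, w, 0.0, 0.0)); // 0.25 (no regularization)
        System.out.println(objective(x, y, w, 0.5, 0.0)); // 0.75 (alpha = 0: pure L2)
        System.out.println(objective(x, y, w, 0.5, 1.0)); // 1.25 (alpha = 1: pure L1)
    }
}
```

Varying `alpha` between 0 and 1 in `main` moves the penalty from the ridge end to the Lasso end of the elastic-net family, matching the regularization list above.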
The huber objective function is:
$$ \begin{align} \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma + H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2} \end{align} $$
where
$$ \begin{align} H_m(z) = \begin{cases} z^2, & \text {if } |z| < \epsilon, \\ 2\epsilon|z| - \epsilon^2, & \text{otherwise} \end{cases} \end{align} $$
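The Huber pieces can be sketched the same way. The class `HuberObjective` below is hypothetical (not a Spark API): it evaluates `H_m` and the Huber objective for a fixed scale `sigma`, whereas the objective above is minimized over `w` and `sigma` jointly, which this sketch does not attempt. The value 1.35 used in `main` is the default epsilon stated in the `setEpsilon` documentation on this page.

```java
/** Hypothetical helper: evaluates H_m and the Huber objective for a fixed sigma. */
public class HuberObjective {

    /** H_m(z) = z^2 if |z| < epsilon, otherwise 2 * epsilon * |z| - epsilon^2. */
    public static double hm(double z, double epsilon) {
        double a = Math.abs(z);
        return a < epsilon ? z * z : 2.0 * epsilon * a - epsilon * epsilon;
    }

    /** (1/2n) * sum_i (sigma + H_m((X_i w - y_i) / sigma) * sigma) + (1/2) * lambda * ||w||_2^2 */
    public static double objective(double[][] x, double[] y, double[] w,
                                   double sigma, double lambda, double epsilon) {
        if (sigma <= 0.0) {
            throw new IllegalArgumentException("sigma must be positive");
        }
        int n = x.length;
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = 0.0;
            for (int j = 0; j < w.length; j++) {
                pred += x[i][j] * w[j];
            }
            // Residual scaled by sigma before applying H_m.
            double z = (pred - y[i]) / sigma;
            sum += sigma + hm(z, epsilon) * sigma;
        }
        // Only an L2 penalty: huber loss supports none and L2 regularization.
        double l2sq = 0.0;
        for (double wj : w) {
            l2sq += wj * wj;
        }
        return sum / (2.0 * n) + 0.5 * lambda * l2sq;
    }

    public static void main(String[] args) {
        double eps = 1.35;
        System.out.println(hm(1.0, eps)); // 1.0 (quadratic branch)
        System.out.println(hm(2.0, eps)); // linear branch: 2 * 1.35 * 2 - 1.35^2
    }
}
```

The two branches of `H_m` agree at |z| = epsilon (both equal epsilon^2) and have the same slope there, so the loss switches smoothly from quadratic to linear growth.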
Note: Fitting with huber loss only supports none and L2 regularization.
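The parameter constraints documented on this page (this note, the notes on `setElasticNetParam` and `setSolver`, and the `MAX_FEATURES_FOR_NORMAL_SOLVER` limit) can be gathered into one pre-flight check. The sketch below is a hypothetical helper, not part of Spark, which performs its own validation; it only mirrors the documented rules. The feature limit is passed in as a parameter because its value is not stated here; the 100 in `main` is an arbitrary placeholder.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper mirroring the documented LinearRegression parameter constraints. */
public class LinearRegressionParamCheck {

    public static List<String> violations(String loss, String solver, double elasticNetParam,
                                          int numFeatures, int maxFeaturesForNormalSolver) {
        List<String> problems = new ArrayList<>();
        boolean huber = "huber".equals(loss);
        // Huber loss supports only none and L2 regularization.
        if (huber && elasticNetParam != 0.0) {
            problems.add("huber loss requires elasticNetParam == 0");
        }
        // Huber loss does not support the normal solver.
        if (huber && "normal".equals(solver)) {
            problems.add("huber loss does not support solver \"normal\"");
        }
        // The normal solver collects the full X^T X on the driver, so features are capped.
        if ("normal".equals(solver) && numFeatures > maxFeaturesForNormalSolver) {
            problems.add("solver \"normal\" allows at most " + maxFeaturesForNormalSolver + " features");
        }
        return problems;
    }

    public static void main(String[] args) {
        int assumedLimit = 100; // placeholder, not the real MAX_FEATURES_FOR_NORMAL_SOLVER value
        System.out.println(violations("squaredError", "auto", 0.5, 10, assumedLimit)); // []
        System.out.println(violations("huber", "normal", 0.5, 10, assumedLimit).size()); // 2
    }
}
```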
Constructor and Description |
---|
LinearRegression() |
LinearRegression(String uid) |
Modifier and Type | Method and Description |
---|---|
static IntParam | aggregationDepth() |
static Params | clear(Param<?> param) |
LinearRegression | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
static DoubleParam | elasticNetParam() |
static DoubleParam | epsilon() |
DoubleParam | epsilon(): The shape parameter to control the amount of robustness. |
static String | explainParam(Param<?> param) |
static String | explainParams() |
static ParamMap | extractParamMap() |
static ParamMap | extractParamMap(ParamMap extra) |
static Param<String> | featuresCol() |
static M | fit(Dataset<?> dataset) |
static M | fit(Dataset<?> dataset, ParamMap paramMap) |
static scala.collection.Seq<M> | fit(Dataset<?> dataset, ParamMap[] paramMaps) |
static M | fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs) |
static M | fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs) |
static BooleanParam | fitIntercept() |
static <T> scala.Option<T> | get(Param<T> param) |
static int | getAggregationDepth() |
static <T> scala.Option<T> | getDefault(Param<T> param) |
static double | getElasticNetParam() |
static double | getEpsilon() |
double | getEpsilon() |
static String | getFeaturesCol() |
static boolean | getFitIntercept() |
static String | getLabelCol() |
static String | getLoss() |
static int | getMaxIter() |
static <T> T | getOrDefault(Param<T> param) |
static Param<Object> | getParam(String paramName) |
static String | getPredictionCol() |
static double | getRegParam() |
static String | getSolver() |
static boolean | getStandardization() |
static double | getTol() |
static String | getWeightCol() |
static <T> boolean | hasDefault(Param<T> param) |
static boolean | hasParam(String paramName) |
static boolean | isDefined(Param<?> param) |
static boolean | isSet(Param<?> param) |
static Param<String> | labelCol() |
static LinearRegression | load(String path) |
static Param<String> | loss() |
Param<String> | loss(): The loss function to be optimized. |
static int | MAX_FEATURES_FOR_NORMAL_SOLVER(): When using LinearRegression.solver == "normal", the solver must limit the number of features to at most this number. |
static IntParam | maxIter() |
static Param<?>[] | params() |
static Param<String> | predictionCol() |
static DoubleParam | regParam() |
static void | save(String path) |
static <T> Params | set(Param<T> param, T value) |
LinearRegression | setAggregationDepth(int value): Suggested depth for treeAggregate (greater than or equal to 2). |
LinearRegression | setElasticNetParam(double value): Set the ElasticNet mixing parameter. |
LinearRegression | setEpsilon(double value): Sets the value of param epsilon. |
static Learner | setFeaturesCol(String value) |
LinearRegression | setFitIntercept(boolean value): Set whether to fit an intercept term. |
static Learner | setLabelCol(String value) |
LinearRegression | setLoss(String value): Sets the value of param loss. |
LinearRegression | setMaxIter(int value): Set the maximum number of iterations. |
static Learner | setPredictionCol(String value) |
LinearRegression | setRegParam(double value): Set the regularization parameter. |
LinearRegression | setSolver(String value): Set the solver algorithm used for optimization. |
LinearRegression | setStandardization(boolean value): Whether to standardize the training features before fitting the model. |
LinearRegression | setTol(double value): Set the convergence tolerance of iterations. |
LinearRegression | setWeightCol(String value): Whether to over-/under-sample training instances according to the given weights in weightCol. |
static Param<String> | solver() |
Param<String> | solver(): The solver algorithm for optimization. |
static BooleanParam | standardization() |
static DoubleParam | tol() |
static String | toString() |
static StructType | transformSchema(StructType schema) |
String | uid(): An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
static Param<String> | weightCol() |
static MLWriter | write() |
fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
getRegParam, regParam
elasticNetParam, getElasticNetParam
getMaxIter, maxIter
fitIntercept, getFitIntercept
getStandardization, standardization
getWeightCol, weightCol
aggregationDepth, getAggregationDepth
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
write
save
initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
public LinearRegression(String uid)
public LinearRegression()
public static LinearRegression load(String path)
public static int MAX_FEATURES_FOR_NORMAL_SOLVER()
When using LinearRegression.solver == "normal", the solver must limit the number of features to at most this number. The entire covariance matrix X^T X will be collected to the driver. This limit helps prevent memory overflow errors.
public static String toString()
public static Param<?>[] params()
public static String explainParam(Param<?> param)
public static String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(String paramName)
public static Param<Object> getParam(String paramName)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
public static final Param<String> labelCol()
public static final String getLabelCol()
public static final Param<String> featuresCol()
public static final String getFeaturesCol()
public static final Param<String> predictionCol()
public static final String getPredictionCol()
public static Learner setLabelCol(String value)
public static Learner setFeaturesCol(String value)
public static Learner setPredictionCol(String value)
public static M fit(Dataset<?> dataset)
public static StructType transformSchema(StructType schema)
public static final DoubleParam regParam()
public static final double getRegParam()
public static final DoubleParam elasticNetParam()
public static final double getElasticNetParam()
public static final IntParam maxIter()
public static final int getMaxIter()
public static final DoubleParam tol()
public static final double getTol()
public static final BooleanParam fitIntercept()
public static final boolean getFitIntercept()
public static final BooleanParam standardization()
public static final boolean getStandardization()
public static final Param<String> weightCol()
public static final String getWeightCol()
public static final String getSolver()
public static final IntParam aggregationDepth()
public static final int getAggregationDepth()
public static final String getLoss()
public static final Param<String> solver()
public static final Param<String> loss()
public static final DoubleParam epsilon()
public static double getEpsilon()
public static void save(String path) throws java.io.IOException
Throws: java.io.IOException
public static MLWriter write()
public String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public LinearRegression setRegParam(double value)
Set the regularization parameter.
value - (undocumented)
public LinearRegression setFitIntercept(boolean value)
Set whether to fit an intercept term.
value - (undocumented)
public LinearRegression setStandardization(boolean value)
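Whether to standardize the training features before fitting the model.
value - (undocumented)
public LinearRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter. In the squared error objective, alpha = 0 gives a pure L2 (ridge) penalty, alpha = 1 gives a pure L1 (Lasso) penalty, and values in between give an elastic net combination.
Note: Fitting with huber loss only supports none and L2 regularization, so an exception is thrown if this param is set to a non-zero value.
value - (undocumented)
public LinearRegression setMaxIter(int value)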
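Set the maximum number of iterations.
value - (undocumented)
public LinearRegression setTol(double value)
Set the convergence tolerance of iterations.
value - (undocumented)
public LinearRegression setWeightCol(String value)
Whether to over-/under-sample training instances according to the given weights in weightCol.
value - (undocumented)
public LinearRegression setSolver(String value)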
Set the solver algorithm used for optimization. The "normal" solver is limited to at most LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER features.
The default, "auto", selects the solver algorithm automatically: the Normal Equations solver is used when possible, with an automatic fallback to iterative optimization methods when needed.
Note: Fitting with huber loss doesn't support the normal solver, so an exception is thrown if this param is set to "normal".
value - (undocumented)
public LinearRegression setAggregationDepth(int value)
Suggested depth for treeAggregate (greater than or equal to 2).
value - (undocumented)
public LinearRegression setLoss(String value)
Sets the value of param loss. Default is "squaredError".
value - (undocumented)
public LinearRegression setEpsilon(double value)
Sets the value of param epsilon. Default is 1.35.
value - (undocumented)
public LinearRegression copy(ParamMap extra)
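Creates a copy of this instance with the same UID and some extra params. See Params.defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,LinearRegression,LinearRegressionModel>
extra - (undocumented)
public DoubleParam epsilon()
The shape parameter to control the amount of robustness.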
public double getEpsilon()
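public Param<String> loss()
The loss function to be optimized. Supported options: "squaredError" and "huber".
public Param<String> solver()
The solver algorithm for optimization.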
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.