Skip to content

Commit

Permalink
Remove raiseErrors from THTensor functions, have THStorage functions …
Browse files Browse the repository at this point in the history
…take an error_buffer to return a proper error message while being able to handle memory management correctly from calling function.
  • Loading branch information
gchanan authored and soumith committed Jun 11, 2017
1 parent 52109da commit 17bd4b8
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 71 deletions.
32 changes: 14 additions & 18 deletions lib/TH/THStorage.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) {
return buf;
}

TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement)
THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement)
{
ptrdiff_t total_size = (size->size > 0 ? 1 : 0);
ptrdiff_t dim_infer = -1;
Expand Down Expand Up @@ -66,7 +66,8 @@ TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t
return copy;
}

TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors) {
int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB,
char *error_buffer, int buffer_len) {
THArgCheck(sizesA != NULL, 1, "sizesA must not be null");
THArgCheck(sizesB != NULL, 2, "sizesB must not be null");
THArgCheck(dimsA, 1, "Can't expand empty tensor a");
Expand All @@ -85,10 +86,8 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di
expandedSizes[i] = THMax(sizeA, sizeB);
} else {
THFree(expandedSizes);
if (raiseErrors) {
THError("The size of tensor a (%ld) must match the size of tensor b (%ld) at "
"non-singleton dimension %ld.", sizeA, sizeB, i);
}
snprintf(error_buffer, buffer_len, "The size of tensor a (%ld) must match the size of tensor b (%ld) at "
"non-singleton dimension %ld.", sizeA, sizeB, i);
return -1;
}
}
Expand All @@ -98,7 +97,8 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di
return 0;
}

TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) {
int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims,
char *error_buffer, int buffer_len) {
THArgCheck(n > 0, 2, "n must be greater than 0");
THArgCheck(sizes != NULL, 1, "sizes must not be null");
THArgCheck(dims != NULL, 1, "dims must not be null");
Expand All @@ -122,10 +122,8 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes,
expandedSizes[ i ] = THMax(expandedSizes[ i ], size);
} else {
THFree(expandedSizes);
if (raiseErrors) {
THError("The size of tensor %i (%ld) must match the expanded size of tensor (%ld) at "
"non-singleton dimension %ld.", j, size, expandedSizes[ i ], i);
}
snprintf(error_buffer, buffer_len, "The size of tensor %i (%ld) must match the expanded size"
"of tensor (%ld) at non-singleton dimension %ld.", j, size, expandedSizes[ i ], i);
return -1;
}
}
Expand All @@ -136,9 +134,9 @@ TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes,
return 0;
}

TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim,
THLongStorage *sizes, long **expandedSizes, long **expandedStrides,
int raiseErrors) {
int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim,
THLongStorage *sizes, long **expandedSizes, long **expandedStrides,
char *error_buffer, int buffer_len) {
ptrdiff_t ndim = THLongStorage_size(sizes);

long *expandedSizesCalc = THAlloc(sizeof(long)*ndim);
Expand All @@ -159,10 +157,8 @@ TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStri
} else {
THFree(expandedSizesCalc);
THFree(expandedStridesCalc);
if (raiseErrors) {
THError("The expanded size of the tensor (%d) must match the existing size (%d) at "
"non-singleton dimension %ld.", targetSize, size, i);
}
snprintf(error_buffer, buffer_len, "The expanded size of the tensor (%d) must match the existing size (%d) at "
"non-singleton dimension %ld.", targetSize, size, i);
return -1;
}
}
Expand Down
10 changes: 6 additions & 4 deletions lib/TH/THStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ typedef struct {
TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size);
TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement);

// Given the sizes of {2,N} tensors, write out the size when the tensors are expanded together
TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors);
TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors);
// Given the sizes of {2,N} tensors, write out the size when the tensors are expanded together.
TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA,
long *sizesB, long dimsB, char *error_buffer, int buffer_len);
TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims,
char *error_buffer, int buffer_len);

TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim,
THLongStorage *sizes, long **expandedSizes, long **expandedStrides,
int raiseErrors);
char *error_buffer, int buffer_len);

#endif
84 changes: 38 additions & 46 deletions lib/TH/generic/THTensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -287,65 +287,63 @@ void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, lon

THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes) {
THTensor *result = THTensor_(new)();
THTensor_(expand)(result, tensor, sizes, 1);
THTensor_(expand)(result, tensor, sizes);
return result;
}

int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes, int raiseErrors) {
if (raiseErrors) {
THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1,
"the number of sizes provided must be greater or equal to the "
"number of dimensions in the tensor");
} else if (THLongStorage_size(sizes) < THTensor_(nDimension)(tensor)) {
return -1;
}
void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes) {
THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1,
"the number of sizes provided must be greater or equal to the "
"number of dimensions in the tensor");

long *expandedSizes;
long *expandedStrides;
char error_buffer[1024];
int ret =
THLongStorage_inferExpandGeometry(tensor->size, tensor->stride, THTensor_(nDimension)(tensor), sizes, &expandedSizes, &expandedStrides, raiseErrors);
THLongStorage_inferExpandGeometry(tensor->size, tensor->stride, THTensor_(nDimension)(tensor),
sizes, &expandedSizes, &expandedStrides, error_buffer, 1024);

if (ret != 0) {
return ret;
THError(error_buffer);
return;
}
THTensor_(setStorageNd)(r, THTensor_(storage)(tensor), THTensor_(storageOffset)(tensor), THLongStorage_size(sizes), expandedSizes, expandedStrides);

THTensor_(setStorageNd)(r, THTensor_(storage)(tensor), THTensor_(storageOffset)(tensor),
THLongStorage_size(sizes), expandedSizes, expandedStrides);
THFree(expandedSizes);
THFree(expandedStrides);

return 0;
}

int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors) {
void THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb) {
THArgCheck(THTensor_(nDimension)(opa) > 0, 0, "can't expand empty tensor opa");
THArgCheck(THTensor_(nDimension)(opb) > 0, 0, "can't expand empty tensor opb");

THLongStorage *sizes = THLongStorage_new();
int ret = THLongStorage_inferSize2(sizes,
opa->size, THTensor_(nDimension)(opa),
opb->size, THTensor_(nDimension)(opb),
raiseErrors);
if(ret != 0) {
char error_buffer[1024];
int ret =THLongStorage_inferSize2(sizes,
opa->size, THTensor_(nDimension)(opa),
opb->size, THTensor_(nDimension)(opb),
error_buffer, 1024);
if (ret != 0) {
THLongStorage_free(sizes);
return ret;
THError(error_buffer);
return;
}

ret = THTensor_(expand)(ra, opa, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
ret = THTensor_(expand)(rb, opb, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
THTensor_(expand)(ra, opa, sizes);
THTensor_(expand)(rb, opb, sizes);

THLongStorage_free(sizes);
return 0;
}

int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors) {
void THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc) {
THArgCheck(THTensor_(nDimension)(opa) > 0, 0, "can't expand empty tensor opa");
THArgCheck(THTensor_(nDimension)(opb) > 0, 0, "can't expand empty tensor opb");
THArgCheck(THTensor_(nDimension)(opc) > 0, 0, "can't expand empty tensor opc");

const int op_n = 3;
long **op_sizes = THAlloc(sizeof(long**)*op_n);
long *op_dims = THAlloc(sizeof(long*)*op_n);
long *op_sizes[3];
long op_dims[3];

op_sizes[ 0 ] = opa->size;
op_sizes[ 1 ] = opb->size;
Expand All @@ -355,33 +353,27 @@ int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa,
op_dims[ 2 ] = opc->nDimension;

THLongStorage *sizes = THLongStorage_new();
char error_buffer[1024];
int ret = THLongStorage_inferSizeN(sizes,
op_n,
3,
op_sizes,
op_dims,
raiseErrors);
error_buffer,
1024);

if(ret != 0) {
THLongStorage_free(sizes);
THFree(op_dims);
THFree(op_sizes);
return ret;
THError(error_buffer);
return;
}

ret = THTensor_(expand)(ra, opa, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
ret = THTensor_(expand)(rb, opb, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
ret = THTensor_(expand)(rc, opc, sizes, raiseErrors);
THAssert(ret == 0); // since we inferred this already, it must be valid
THTensor_(expand)(ra, opa, sizes);
THTensor_(expand)(rb, opb, sizes);
THTensor_(expand)(rc, opc, sizes);

THLongStorage_free(sizes);
THFree(op_dims);
THFree(op_sizes);
return 0;
}


void THTensor_(set)(THTensor *self, THTensor *src)
{
if(self != src)
Expand Down
6 changes: 3 additions & 3 deletions lib/TH/generic/THTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long siz
TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size);
TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size);

TH_API int THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size, int raiseErrors);
TH_API int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors);
TH_API int THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc, int raiseErrors);
TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size);
TH_API void THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb);
TH_API void THTensor_(expand3)(THTensor *ra, THTensor *rb, THTensor *rc, THTensor *opa, THTensor *opb, THTensor *opc);

TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);
Expand Down

0 comments on commit 17bd4b8

Please sign in to comment.